You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Differential privacy model.
  16. """
  17. from easydict import EasyDict as edict
  18. from mindspore.train.model import Model
  19. from mindspore._checkparam import Validator as validator
  20. from mindspore._checkparam import Rel
  21. from mindspore.train import amp
  22. from mindspore.train.amp import _config_level
  23. from mindspore.common import dtype as mstype
  24. from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
  25. from mindspore.parallel._utils import _get_parallel_mode
  26. from mindspore.train.model import ParallelMode
  27. from mindspore.train.amp import _do_keep_batchnorm_fp32
  28. from mindspore.train.amp import _add_loss_network
  29. from mindspore import context
  30. from mindspore import nn
  31. from mindspore import Tensor
  32. from mindspore.ops import composite as C
  33. from mindspore.ops import operations as P
  34. from mindspore.ops import functional as F
  35. from mindspore.ops.operations import NPUGetFloatStatus
  36. from mindspore.ops.operations import NPUAllocFloatStatus
  37. from mindspore.ops.operations import NPUClearFloatStatus
  38. from mindspore.ops.operations import ReduceSum
  39. from mindspore.ops.operations import LessEqual
  40. from mindspore.ops.operations import ControlDepend
  41. from mindspore.parallel._utils import _get_mirror_mean
  42. from mindspore.parallel._utils import _get_device_num
  43. from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
  44. from mindspore.common.parameter import Parameter
  45. from mindspore.nn.wrap.loss_scale import _grad_overflow
  46. from mindspore.nn import Cell
  47. from mindspore import ParameterTuple
  48. from mindarmour.utils._check_param import check_param_type
  49. from mindarmour.utils._check_param import check_value_positive
  50. from mindarmour.utils._check_param import check_int_positive
  51. GRADIENT_CLIP_TYPE = 1
  52. _grad_scale = C.MultitypeFuncGraph("grad_scale")
  53. _reciprocal = P.Reciprocal()
  54. @_grad_scale.register("Tensor", "Tensor")
  55. def tensor_grad_scale(scale, grad):
  56. """ grad scaling """
  57. return grad * F.cast(_reciprocal(scale), F.dtype(grad))
  58. class DPModel(Model):
  59. """
  60. This class is overload mindspore.train.model.Model.
  61. Args:
  62. micro_batches (int): The number of small batches split from an original batch. Default: 2.
  63. norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
  64. mech (Mechanisms): The object can generate the different type of noise. Default: None.
  65. Examples:
  66. >>> class Net(nn.Cell):
  67. >>> def __init__(self):
  68. >>> super(Net, self).__init__()
  69. >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
  70. >>> self.bn = nn.BatchNorm2d(64)
  71. >>> self.relu = nn.ReLU()
  72. >>> self.flatten = nn.Flatten()
  73. >>> self.fc = nn.Dense(64*224*224, 12) # padding=0
  74. >>>
  75. >>> def construct(self, x):
  76. >>> x = self.conv(x)
  77. >>> x = self.bn(x)
  78. >>> x = self.relu(x)
  79. >>> x = self.flatten(x)
  80. >>> out = self.fc(x)
  81. >>> return out
  82. >>>
  83. >>> net = Net()
  84. >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  85. >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
  86. >>> mech = MechanismsFactory().create('Gaussian',
  87. >>> norm_bound=args.l2_norm_bound,
  88. >>> initial_noise_multiplier=args.initial_noise_multiplier)
  89. >>> model = DPModel(micro_batches=2,
  90. >>> norm_clip=1.0,
  91. >>> mech=mech,
  92. >>> network=net,
  93. >>> loss_fn=loss,
  94. >>> optimizer=net_opt,
  95. >>> metrics=None)
  96. >>> dataset = get_dataset()
  97. >>> model.train(2, dataset)
  98. """
  99. def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
  100. if micro_batches:
  101. self._micro_batches = check_int_positive('micro_batches', micro_batches)
  102. else:
  103. self._micro_batches = None
  104. float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
  105. self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
  106. if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
  107. raise ValueError('DPOptimizer is not supported while mech is not None')
  108. if mech is None:
  109. if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
  110. if context.get_context('mode') != context.PYNATIVE_MODE:
  111. raise ValueError('DPOptimizer just support pynative mode currently.')
  112. else:
  113. raise ValueError('DPModel should set mech or DPOptimizer configure, please refer to example.')
  114. self._mech = mech
  115. super(DPModel, self).__init__(**kwargs)
  116. def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
  117. """
  118. Build the mixed precision training cell automatically.
  119. Args:
  120. network (Cell): Definition of the network.
  121. loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
  122. Default: None.
  123. optimizer (Optimizer): Optimizer to update the Parameter.
  124. level (str): Supports [O0, O2]. Default: "O0".
  125. - O0: Do not change.
  126. - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
  127. using dynamic loss scale.
  128. cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
  129. If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
  130. keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
  131. loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
  132. scale the loss by LossScaleManager. If set, overwrite the level setting.
  133. """
  134. validator.check_value_type('network', network, nn.Cell, None)
  135. validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
  136. validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
  137. self._check_kwargs(kwargs)
  138. config = dict(_config_level[level], **kwargs)
  139. config = edict(config)
  140. if config.cast_model_type == mstype.float16:
  141. network.to_float(mstype.float16)
  142. if config.keep_batchnorm_fp32:
  143. _do_keep_batchnorm_fp32(network)
  144. if loss_fn:
  145. network = _add_loss_network(network, loss_fn, config.cast_model_type)
  146. if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  147. network = _VirtualDatasetCell(network)
  148. loss_scale = 1.0
  149. if config.loss_scale_manager is not None:
  150. loss_scale_manager = config.loss_scale_manager
  151. loss_scale = loss_scale_manager.get_loss_scale()
  152. update_cell = loss_scale_manager.get_update_cell()
  153. if update_cell is not None:
  154. # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
  155. if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
  156. raise ValueError("Only `loss_scale_manager=None` and "
  157. "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
  158. "are supported in current version. If you use `O2` option, please"
  159. "use `loss_scale_manager=None` or `FixedLossScaleManager`")
  160. network = _TrainOneStepWithLossScaleCell(network,
  161. optimizer,
  162. scale_update_cell=update_cell,
  163. micro_batches=self._micro_batches,
  164. l2_norm_clip=self._norm_clip,
  165. mech=self._mech).set_train()
  166. return network
  167. network = _TrainOneStepCell(network,
  168. optimizer,
  169. loss_scale,
  170. micro_batches=self._micro_batches,
  171. l2_norm_clip=self._norm_clip,
  172. mech=self._mech).set_train()
  173. return network
  174. def _build_train_network(self):
  175. """Build train network"""
  176. network = self._network
  177. if self._micro_batches:
  178. if self._optimizer:
  179. if self._loss_scale_manager_set:
  180. network = self._amp_build_train_network(network,
  181. self._optimizer,
  182. self._loss_fn,
  183. level=self._amp_level,
  184. loss_scale_manager=self._loss_scale_manager,
  185. keep_batchnorm_fp32=self._keep_bn_fp32)
  186. else:
  187. network = self._amp_build_train_network(network,
  188. self._optimizer,
  189. self._loss_fn,
  190. level=self._amp_level,
  191. keep_batchnorm_fp32=self._keep_bn_fp32)
  192. elif self._loss_fn:
  193. network = nn.WithLossCell(network, self._loss_fn)
  194. else:
  195. if self._optimizer:
  196. if self._loss_scale_manager_set:
  197. network = amp.build_train_network(network,
  198. self._optimizer,
  199. self._loss_fn,
  200. level=self._amp_level,
  201. loss_scale_manager=self._loss_scale_manager,
  202. keep_batchnorm_fp32=self._keep_bn_fp32)
  203. else:
  204. network = amp.build_train_network(network,
  205. self._optimizer,
  206. self._loss_fn,
  207. level=self._amp_level,
  208. keep_batchnorm_fp32=self._keep_bn_fp32)
  209. elif self._loss_fn:
  210. network = nn.WithLossCell(network, self._loss_fn)
  211. if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  212. network.set_auto_parallel()
  213. return network
  214. class _ClipGradients(nn.Cell):
  215. """
  216. Clip gradients.
  217. Inputs:
  218. grads (tuple[Tensor]): Gradients.
  219. clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
  220. clip_value (float): Specifies how much to clip.
  221. Outputs:
  222. tuple[Tensor], clipped gradients.
  223. """
  224. def __init__(self):
  225. super(_ClipGradients, self).__init__()
  226. self.clip_by_norm = nn.ClipByNorm()
  227. self.dtype = P.DType()
  228. def construct(self, grads, clip_type, clip_value):
  229. """
  230. construct a compute flow.
  231. """
  232. # pylint: disable=consider-using-in
  233. if clip_type != 0 and clip_type != 1:
  234. return grads
  235. new_grads = ()
  236. for grad in grads:
  237. if clip_type == 0:
  238. t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)),
  239. F.tuple_to_array((clip_value,)))
  240. else:
  241. t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,)))
  242. new_grads = new_grads + (t,)
  243. return new_grads
  244. class _TupleAdd(nn.Cell):
  245. def __init__(self):
  246. super(_TupleAdd, self).__init__()
  247. self.add = P.TensorAdd()
  248. self.hyper_map = C.HyperMap()
  249. def construct(self, input1, input2):
  250. """Add two tuple of data."""
  251. out = self.hyper_map(self.add, input1, input2)
  252. return out
  253. class _TrainOneStepWithLossScaleCell(Cell):
  254. r"""
  255. Network training with loss scaling.
  256. This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
  257. Cell as args. The loss scale value can be updated in both host side or device side. The
  258. TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input
  259. data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should
  260. be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`.
  261. If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
  262. Args:
  263. network (Cell): The training network.
  264. optimizer (Cell): Optimizer for updating the weights.
  265. scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
  266. micro_batches (int): The number of small batches split from an original batch. Default: None.
  267. l2_norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
  268. mech (Mechanisms): The object can generate the different type of noise. Default: None.
  269. Inputs:
  270. - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
  271. - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
  272. - **scaling_sens** (Tensor) - Tensor of shape :math:`()`.
  273. Outputs:
  274. Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.
  275. - **loss** (Tensor) - Tensor with shape :math:`()`.
  276. - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
  277. - **loss_scale** (Tensor) - Tensor with shape :math:`()`.
  278. """
  279. def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None):
  280. super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
  281. self.network = network
  282. self.network.set_grad()
  283. self.network.add_flags(defer_inline=True)
  284. self.weights = ParameterTuple(network.trainable_params())
  285. self.optimizer = optimizer
  286. self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
  287. self.hyper_map = C.HyperMap()
  288. if context.get_context("device_target") == "GPU":
  289. self.gpu_target = True
  290. self.float_status = P.FloatStatus()
  291. self.addn = P.AddN()
  292. self.reshape = P.Reshape()
  293. else:
  294. self.gpu_target = False
  295. self.alloc_status = NPUAllocFloatStatus()
  296. self.get_status = NPUGetFloatStatus()
  297. self.clear_status = NPUClearFloatStatus()
  298. self.reduce_sum = ReduceSum(keep_dims=False)
  299. self.base = Tensor(1, mstype.float32)
  300. self.less_equal = LessEqual()
  301. self.depend_parameter_use = ControlDepend(depend_mode=1)
  302. self.allreduce = P.AllReduce()
  303. self.parallel_mode = _get_parallel_mode()
  304. self.grad_reducer = F.identity
  305. self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
  306. if self.reducer_flag:
  307. mean = _get_mirror_mean()
  308. degree = _get_device_num()
  309. self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
  310. self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
  311. self.loss_scale = None
  312. self.loss_scaling_manager = scale_update_cell
  313. if scale_update_cell:
  314. self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
  315. name="loss_scale")
  316. self.add_flags(has_effect=True)
  317. # dp params
  318. self._micro_batches = micro_batches
  319. float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
  320. self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
  321. self._split = P.Split(0, self._micro_batches)
  322. self._clip_by_global_norm = _ClipGradients()
  323. self._mech = mech
  324. self._tuple_add = _TupleAdd()
  325. self._hyper_map = C.HyperMap()
  326. self._micro_float = Tensor(micro_batches, mstype.float32)
  327. def construct(self, data, label, sens=None):
  328. """
  329. construct a compute flow.
  330. """
  331. init = False
  332. if not self.gpu_target:
  333. # init overflow buffer
  334. init = self.alloc_status()
  335. # clear overflow buffer
  336. self.clear_status(init)
  337. if sens is None:
  338. scaling_sens = self.loss_scale
  339. else:
  340. scaling_sens = sens
  341. # DP clip
  342. weights = self.weights
  343. record_datas = self._split(data)
  344. record_labels = self._split(label)
  345. # first index
  346. loss = self.network(record_datas[0], record_labels[0])
  347. scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
  348. record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
  349. record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
  350. grads = record_grad
  351. total_loss = loss
  352. for i in range(1, self._micro_batches):
  353. loss = self.network(record_datas[i], record_labels[i])
  354. scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
  355. record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
  356. record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
  357. grads = self._tuple_add(grads, record_grad)
  358. total_loss = P.TensorAdd()(total_loss, loss)
  359. loss = P.Div()(total_loss, self._micro_float)
  360. if self._mech is not None:
  361. grad_noise = self._hyper_map(self._mech, grads)
  362. grads = self._tuple_add(grads, grad_noise)
  363. grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
  364. grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
  365. # apply grad reducer on grads
  366. grads = self.grad_reducer(grads)
  367. # get the overflow buffer
  368. if not self.gpu_target:
  369. self.get_status(init)
  370. # sum overflow buffer elements, 0:not overflow , >0:overflow
  371. flag_sum = self.reduce_sum(init, (0,))
  372. else:
  373. flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
  374. flag_sum = self.addn(flag_sum)
  375. # convert flag_sum to scalar
  376. flag_sum = self.reshape(flag_sum, (()))
  377. if self.is_distributed:
  378. # sum overflow flag over devices
  379. flag_reduce = self.allreduce(flag_sum)
  380. cond = self.less_equal(self.base, flag_reduce)
  381. else:
  382. cond = self.less_equal(self.base, flag_sum)
  383. overflow = cond
  384. if sens is None:
  385. overflow = self.loss_scaling_manager(self.loss_scale, cond)
  386. # if there is no overflow, do optimize
  387. if overflow:
  388. opt = False
  389. else:
  390. opt = self.optimizer(grads)
  391. ret = (loss, cond, scaling_sens)
  392. return F.depend(ret, opt)
  393. class _TrainOneStepCell(Cell):
  394. r"""
  395. Network training package class.
  396. Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
  397. Backward graph will be created in the construct function to do parameter updating. Different
  398. parallel modes are available to run the training.
  399. Args:
  400. network (Cell): The training network.
  401. optimizer (Cell): Optimizer for updating the weights.
  402. sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
  403. micro_batches (int): The number of small batches split from an original batch. Default: None.
  404. l2_norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
  405. mech (Mechanisms): The object can generate the different type of noise. Default: None.
  406. Inputs:
  407. - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
  408. - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
  409. Outputs:
  410. Tensor, a scalar Tensor with shape :math:`()`.
  411. """
  412. def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None):
  413. super(_TrainOneStepCell, self).__init__(auto_prefix=False)
  414. self.network = network
  415. self.network.set_grad()
  416. self.network.add_flags(defer_inline=True)
  417. self.weights = optimizer.parameters
  418. self.optimizer = optimizer
  419. self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
  420. self.sens = sens
  421. self.reducer_flag = False
  422. self.grad_reducer = None
  423. parallel_mode = _get_parallel_mode()
  424. if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
  425. self.reducer_flag = True
  426. if self.reducer_flag:
  427. mean = _get_mirror_mean()
  428. degree = _get_device_num()
  429. self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
  430. # dp params
  431. self._micro_batches = micro_batches
  432. float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
  433. self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
  434. self._split = P.Split(0, self._micro_batches)
  435. self._clip_by_global_norm = _ClipGradients()
  436. self._mech = mech
  437. self._tuple_add = _TupleAdd()
  438. self._hyper_map = C.HyperMap()
  439. self._micro_float = Tensor(micro_batches, mstype.float32)
  440. def construct(self, data, label):
  441. """
  442. construct a compute flow.
  443. """
  444. weights = self.weights
  445. record_datas = self._split(data)
  446. record_labels = self._split(label)
  447. loss = self.network(record_datas[0], record_labels[0])
  448. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  449. record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
  450. record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
  451. grads = record_grad
  452. total_loss = loss
  453. for i in range(1, self._micro_batches):
  454. loss = self.network(record_datas[i], record_labels[i])
  455. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  456. record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
  457. record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
  458. grads = self._tuple_add(grads, record_grad)
  459. total_loss = P.TensorAdd()(total_loss, loss)
  460. loss = P.Div()(total_loss, self._micro_float)
  461. if self._mech is not None:
  462. grad_noise = self._hyper_map(self._mech, grads)
  463. grads = self._tuple_add(grads, grad_noise)
  464. grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
  465. if self.reducer_flag:
  466. # apply grad reducer on grads
  467. grads = self.grad_reducer(grads)
  468. return F.depend(loss, self.optimizer(grads))

MindArmour关注AI的安全和隐私问题。致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块:对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。