
model.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Differential privacy model.
"""
from easydict import EasyDict as edict

from mindspore.train.model import Model
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.train import amp
from mindspore.train.amp import _config_level
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.train.model import ParallelMode
from mindspore.train.amp import _do_keep_batchnorm_fp32
from mindspore.train.amp import _add_loss_network
from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import NPUGetFloatStatus
from mindspore.ops.operations import NPUAllocFloatStatus
from mindspore.ops.operations import NPUClearFloatStatus
from mindspore.ops.operations import ReduceSum
from mindspore.ops.operations import LessEqual
from mindspore.ops.operations import ControlDepend
from mindspore.parallel._utils import _get_gradients_mean
from mindspore.parallel._utils import _get_device_num
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.common.parameter import Parameter
from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple

from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_value_positive, check_param_type
from mindarmour.utils._check_param import check_int_positive

from ..mechanisms.mechanisms import _MechanismsParamsUpdater

LOGGER = LogUtil.get_instance()
TAG = 'DP model'

GRADIENT_CLIP_TYPE = 1
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """ grad scaling """
    return grad * F.cast(_reciprocal(scale), F.dtype(grad))

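# Overview of the DP training flow implemented in this file: each input batch
# is split into `micro_batches` slices; gradients are computed per slice,
# clipped to the global l2-norm bound `norm_bound`, accumulated, optionally
# perturbed by `noise_mech`, and averaged before the optimizer step. When
# `clip_mech` is set, the fraction of slices whose gradient norm falls below
# `norm_bound` (`beta`) is used to adapt `norm_bound` after each step.
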
class DPModel(Model):
    """
    This class overloads mindspore.train.model.Model.

    Args:
        micro_batches (int): The number of small batches split from an original
            batch. Default: 2.
        norm_bound (float): Clipping bound for the l2 norm of the gradients.
            Default: 1.0.
        noise_mech (Mechanisms): The object that generates noise of the
            configured type. Default: None.
        clip_mech (Mechanisms): The object used to update the adaptive clipping
            bound. Default: None.

    Raises:
        ValueError: If DPOptimizer and noise_mech are both None or both not None.
        ValueError: If noise_mech or DPOptimizer's mech method is adaptive while
            clip_mech is not None.

    Examples:
        >>> norm_bound = 1.0
        >>> initial_noise_multiplier = 0.01
        >>> network = LeNet5()
        >>> batch_size = 32
        >>> batches = 128
        >>> epochs = 1
        >>> micro_batches = 2
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        >>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
        >>> factory_opt.set_mechanisms('Gaussian',
        ...                            norm_bound=norm_bound,
        ...                            initial_noise_multiplier=initial_noise_multiplier)
        >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
        ...                                          learning_rate=0.1, momentum=0.9)
        >>> clip_mech = ClipMechanismsFactory().create('Gaussian',
        ...                                            decay_policy='Linear',
        ...                                            learning_rate=0.01,
        ...                                            target_unclipped_quantile=0.9,
        ...                                            fraction_stddev=0.01)
        >>> model = DPModel(micro_batches=micro_batches,
        ...                 norm_bound=norm_bound,
        ...                 clip_mech=clip_mech,
        ...                 noise_mech=None,
        ...                 network=network,
        ...                 loss_fn=loss,
        ...                 optimizer=net_opt,
        ...                 metrics=None)
        >>> ms_ds = ds.GeneratorDataset(dataset_generator,
        ...                             ['data', 'label'])
        >>> model.train(epochs, ms_ds, dataset_sink_mode=False)
    """

    def __init__(self, micro_batches=2, norm_bound=1.0, noise_mech=None,
                 clip_mech=None, **kwargs):
        if micro_batches:
            self._micro_batches = check_int_positive('micro_batches',
                                                     micro_batches)
        else:
            self._micro_batches = None
        norm_bound = check_param_type('norm_bound', norm_bound, float)
        norm_bound = check_value_positive('norm_bound', norm_bound)
        norm_bound = Tensor(norm_bound, mstype.float32)
        self._norm_bound = Parameter(norm_bound, 'norm_bound')

        opt = kwargs['optimizer']
        opt_name = opt.__class__.__name__
        # Check whether noise_mech and DPOptimizer are both None or both not
        # None; if so, raise ValueError. Also check whether noise_mech or
        # DPOptimizer's mech method is adaptive while clip_mech is not None;
        # if so, raise ValueError too.
        if noise_mech is not None and "DPOptimizer" in opt_name:
            msg = 'DPOptimizer is not supported while noise_mech is not None'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        if noise_mech is None:
            if "DPOptimizer" in opt_name:
                if 'Ada' in opt._mech.__class__.__name__ and clip_mech is not None:
                    msg = "When DPOptimizer's mech method is adaptive, clip_mech must be None."
                    LOGGER.error(TAG, msg)
                    raise ValueError(msg)
            else:
                msg = 'DPModel should set noise_mech or DPOptimizer configure, ' \
                      'please refer to example.'
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
        self._noise_mech = noise_mech
        if noise_mech is not None:
            if 'Ada' in noise_mech.__class__.__name__ and clip_mech is not None:
                msg = 'When noise_mech is Adaptive, clip_mech must be None.'
                LOGGER.error(TAG, msg)
                raise ValueError(msg)

        if clip_mech is None or isinstance(clip_mech, Cell):
            self._clip_mech = clip_mech
        super(DPModel, self).__init__(**kwargs)

    def _amp_build_train_network(self, network, optimizer, loss_fn=None,
                                 level='O0', **kwargs):
        """
        Build the mixed precision training cell automatically.

        Args:
            network (Cell): Definition of the network.
            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
                the `network` should have the loss inside. Default: None.
            optimizer (Optimizer): Optimizer to update the Parameter.
            level (str): Supports [O0, O2]. Default: "O0".

                - O0: Do not change.
                - O2: Cast network to float16, keep batchnorm and `loss_fn`
                  (if set) run in float32, using dynamic loss scale.

            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
                or `mstype.float32`. If set to `mstype.float16`, use `float16`
                mode to train. If set, overwrite the level setting.
            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
                overwrite the level setting.
            loss_scale_manager (Union[None, LossScaleManager]): If None, not
                scale the loss, or else scale the loss by LossScaleManager.
                If set, overwrite the level setting.
        """
        validator.check_value_type('network', network, nn.Cell, None)
        validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
        validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
        self._check_kwargs(kwargs)
        config = dict(_config_level[level], **kwargs)
        config = edict(config)

        if config.cast_model_type == mstype.float16:
            network.to_float(mstype.float16)

            if config.keep_batchnorm_fp32:
                _do_keep_batchnorm_fp32(network)

        if loss_fn:
            network = _add_loss_network(network, loss_fn,
                                        config.cast_model_type)

        if _get_parallel_mode() in (
                ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network = _VirtualDatasetCell(network)

        loss_scale = 1.0
        if config.loss_scale_manager is not None:
            loss_scale_manager = config.loss_scale_manager
            loss_scale = loss_scale_manager.get_loss_scale()
            update_cell = loss_scale_manager.get_update_cell()
            if update_cell is not None:
                # Only CPU does not support `TrainOneStepWithLossScaleCell`
                # for control flow.
                if not context.get_context("enable_ge") and context.get_context(
                        "device_target") == "CPU":
                    msg = "Only `loss_scale_manager=None` and " \
                          "`loss_scale_manager=FixedLossScaleManager(drop_overflow" \
                          "_update=False)` are supported in current version. " \
                          "If you use `O2` option, please use " \
                          "`loss_scale_manager=None` or `FixedLossScaleManager`"
                    LOGGER.error(TAG, msg)
                    raise ValueError(msg)
                network = _TrainOneStepWithLossScaleCell(network,
                                                         optimizer,
                                                         scale_update_cell=update_cell,
                                                         micro_batches=self._micro_batches,
                                                         norm_bound=self._norm_bound,
                                                         clip_mech=self._clip_mech,
                                                         noise_mech=self._noise_mech).set_train()
                return network

        network = _TrainOneStepCell(network,
                                    optimizer,
                                    self._norm_bound,
                                    loss_scale,
                                    micro_batches=self._micro_batches,
                                    clip_mech=self._clip_mech,
                                    noise_mech=self._noise_mech).set_train()
        return network

    def _build_train_network(self):
        """Build train network."""
        network = self._network
        if self._micro_batches:
            if self._optimizer:
                if self._loss_scale_manager_set:
                    network = self._amp_build_train_network(network,
                                                            self._optimizer,
                                                            self._loss_fn,
                                                            level=self._amp_level,
                                                            loss_scale_manager=self._loss_scale_manager,
                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
                else:
                    network = self._amp_build_train_network(network,
                                                            self._optimizer,
                                                            self._loss_fn,
                                                            level=self._amp_level,
                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
            elif self._loss_fn:
                network = nn.WithLossCell(network, self._loss_fn)
        else:
            if self._optimizer:
                if self._loss_scale_manager_set:
                    network = amp.build_train_network(network,
                                                      self._optimizer,
                                                      self._loss_fn,
                                                      level=self._amp_level,
                                                      loss_scale_manager=self._loss_scale_manager,
                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
                else:
                    network = amp.build_train_network(network,
                                                      self._optimizer,
                                                      self._loss_fn,
                                                      level=self._amp_level,
                                                      keep_batchnorm_fp32=self._keep_bn_fp32)
            elif self._loss_fn:
                network = nn.WithLossCell(network, self._loss_fn)

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                   ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()
        return network

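# Note on network construction above: `_build_train_network` routes to the
# DP-aware `_amp_build_train_network` only when `micro_batches` is set;
# otherwise it falls back to the standard `amp.build_train_network`, so DPModel
# behaves like a plain mindspore Model in that case.
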
class _ClipGradients(nn.Cell):
    """
    Clip gradients.

    Inputs:
        grads (tuple[Tensor]): Gradients.
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.

    Outputs:
        tuple[Tensor], clipped gradients.
    """

    def __init__(self):
        super(_ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.dtype = P.DType()

    def construct(self, grads, clip_type, clip_value):
        """
        construct a compute flow.
        """
        if clip_type not in (0, 1):
            return grads

        new_grads = ()
        for grad in grads:
            if clip_type == 0:
                norm = C.clip_by_value(grad, -clip_value, clip_value)
            else:
                norm = self.clip_by_norm(grad, clip_value)
            new_grads = new_grads + (norm,)

        return new_grads

class _TupleAdd(nn.Cell):
    def __init__(self):
        super(_TupleAdd, self).__init__()
        self.add = P.Add()
        self.hyper_map = C.HyperMap()

    def construct(self, input1, input2):
        """Add two tuples of data."""
        out = self.hyper_map(self.add, input1, input2)
        return out

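# The two cells below are DP variants of MindSpore's train-one-step wrappers:
# `_TrainOneStepWithLossScaleCell` adds overflow detection and dynamic loss
# scaling (selected when a loss-scale manager with an update cell is
# configured), while `_TrainOneStepCell` uses a fixed `sens` value. Both
# implement the per-micro-batch clip / accumulate / noise / average flow and,
# with `clip_mech`, the adaptive update of `norm_bound`.
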
class _TrainOneStepWithLossScaleCell(Cell):
    r"""
    Network training with loss scaling.

    This is a training step with loss scaling. It takes a network, an optimizer
    and possibly a scale update Cell as args. The loss scale value can be
    updated on either the host side or the device side. The
    _TrainOneStepWithLossScaleCell will be compiled into a graph which takes
    `data`, `label` and `sens` as input data. The `sens` acts as the loss
    scaling value. If you want to update it on the host side, the value should
    be provided. If `sens` is not given, the loss scale update logic should be
    provided by `scale_update_cell`. If `scale_update_cell` is not None and
    `sens` is provided, the `scale_update_cell` will be ignored.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        scale_update_cell (Cell): The loss scaling update logic cell.
            Default: None.
        micro_batches (int): The number of small batches split from an original
            batch. Default: None.
        norm_bound (Tensor): Clipping bound for the l2 norm of the gradients.
            Default: 1.0.
        noise_mech (Mechanisms): The object that generates noise of the
            configured type. Default: None.

    Inputs:
        - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **scaling_sens** (Tensor) - Tensor of shape :math:`()`.

    Outputs:
        Tuple of 3 Tensors: the loss, the overflow flag and the current loss
        scaling value.

        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
        - **loss_scale** (Tensor) - Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, scale_update_cell=None,
                 micro_batches=None, norm_bound=1.0, noise_mech=None,
                 clip_mech=None):
        super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.hyper_map = C.HyperMap()
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.gpu_target = False
            self.alloc_status = NPUAllocFloatStatus()
            self.get_status = NPUGetFloatStatus()
            self.clear_status = NPUClearFloatStatus()
        self.reduce_sum = ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = LessEqual()
        self.depend_parameter_use = ControlDepend(depend_mode=1)
        self.allreduce = P.AllReduce()
        self.parallel_mode = _get_parallel_mode()
        self.grad_reducer = F.identity
        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL,
                                                   ParallelMode.HYBRID_PARALLEL]
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters,
                                                       mean, degree)
        self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.add_flags(has_effect=True)

        # dp params
        self._micro_batches = micro_batches
        self._norm_bound = norm_bound
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
        self._noise_mech = noise_mech
        self._clip_mech = clip_mech
        self._add = P.Add()
        self._norm = nn.Norm()
        self._tuple_add = _TupleAdd()
        self._hyper_map = C.HyperMap()
        self._micro_float = Tensor(micro_batches, mstype.float32)
        self._zero = Tensor(0, mstype.float32)
        self._assign = P.Assign()
        self._div = P.Div()
        self._sqrt = P.Sqrt()
        self._reduce_sum = P.ReduceSum()
        self._square_all = P.Square()
        self._less = P.Less()
        self._cast = P.Cast()

        self._noise_mech_param_updater = None
        if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
            self._noise_mech_param_updater = _MechanismsParamsUpdater(
                decay_policy=self._noise_mech._decay_policy,
                decay_rate=self._noise_mech._noise_decay_rate,
                cur_noise_multiplier=self._noise_mech._noise_multiplier,
                init_noise_multiplier=self._noise_mech._initial_noise_multiplier)

    def construct(self, data, label, sens=None):
        """
        construct a compute flow.
        """
        init = False
        if not self.gpu_target:
            # init overflow buffer
            init = self.alloc_status()
            # clear overflow buffer
            self.clear_status(init)

        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # DP clip
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)
        # first index
        loss = self.network(record_datas[0], record_labels[0])
        scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens,
                                                       F.dtype(loss))
        record_grad = self.grad(self.network, weights)(record_datas[0],
                                                       record_labels[0],
                                                       scaling_sens_filled)

        beta = self._zero
        square_sum = self._zero
        for grad in record_grad:
            square_sum = self._add(square_sum,
                                   self._reduce_sum(self._square_all(grad)))
        norm_grad = self._sqrt(square_sum)
        beta = self._add(beta,
                         self._cast(self._less(norm_grad, self._norm_bound),
                                    mstype.float32))

        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
                                                self._norm_bound)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens,
                                                           F.dtype(loss))
            record_grad = self.grad(self.network, weights)(record_datas[i],
                                                           record_labels[i],
                                                           scaling_sens_filled)

            square_sum = self._zero
            for grad in record_grad:
                square_sum = self._add(square_sum,
                                       self._reduce_sum(self._square_all(grad)))
            norm_grad = self._sqrt(square_sum)
            beta = self._add(beta,
                             self._cast(self._less(norm_grad, self._norm_bound),
                                        mstype.float32))

            record_grad = self._clip_by_global_norm(record_grad,
                                                    GRADIENT_CLIP_TYPE,
                                                    self._norm_bound)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
        loss = P.Div()(total_loss, self._micro_float)
        beta = self._div(beta, self._micro_batches)

        if self._noise_mech is not None:
            grad_noise_tuple = ()
            for grad_item in grads:
                grad_noise = self._noise_mech(grad_item)
                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
            grads = self._tuple_add(grads, grad_noise_tuple)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
                                    grads)
            # update mech parameters
            if self._noise_mech_param_updater is not None:
                multiplier = self._noise_mech_param_updater()
                loss = F.depend(loss, multiplier)

        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)

        # get the overflow buffer
        if not self.gpu_target:
            self.get_status(init)
            # sum overflow buffer elements; 0: no overflow, >0: overflow
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            # convert flag_sum to scalar
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # if there is no overflow, do optimize
        if overflow:
            opt = False
        else:
            opt = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        if self._clip_mech is not None:
            next_norm_bound = self._clip_mech(beta, self._norm_bound)
            self._assign(self._norm_bound, next_norm_bound)
        return F.depend(ret, opt)

class _TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell is trained with
    input data and label. A backward graph will be created in the construct
    function to do parameter updating. Different parallel modes are available
    for training.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of back
            propagation. Default value is 1.0.
        micro_batches (int): The number of small batches split from an original
            batch. Default: None.
        norm_bound (Tensor): Clipping bound for the l2 norm of the gradients.
            Default: 1.0.
        noise_mech (Mechanisms): The object that generates noise of the
            configured type. Default: None.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, norm_bound=1.0, sens=1.0,
                 micro_batches=None,
                 noise_mech=None, clip_mech=None):
        super(_TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (
                ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters,
                                                       mean, degree)

        # dp params
        if micro_batches is None:
            msg = 'micro_batches must be given for differential privacy, but got value: {}'.format(
                micro_batches)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        self._micro_batches = micro_batches
        self._norm_bound = norm_bound
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
        self._noise_mech = noise_mech
        self._clip_mech = clip_mech
        self._tuple_add = _TupleAdd()
        self._add = P.Add()
        self._norm = nn.Norm()
        self._hyper_map = C.HyperMap()
        self._zero = Tensor(0, mstype.float32)
        self._assign = P.Assign()
        self._div = P.Div()
        self._sqrt = P.Sqrt()
        self._reduce_sum = P.ReduceSum()
        self._square_all = P.Square()
        self._less = P.Less()
        self._cast = P.Cast()
        self._micro_float = Tensor(micro_batches, mstype.float32)

        self._noise_mech_param_updater = None
        if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
            self._noise_mech_param_updater = _MechanismsParamsUpdater(
                decay_policy=self._noise_mech._decay_policy,
                decay_rate=self._noise_mech._noise_decay_rate,
                cur_noise_multiplier=self._noise_mech._noise_multiplier,
                init_noise_multiplier=self._noise_mech._initial_noise_multiplier)

    def construct(self, data, label):
        """
        construct a compute flow.
        """
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)
        loss = self.network(record_datas[0], record_labels[0])
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[0],
                                                       record_labels[0], sens)

        beta = self._zero
        # calculate beta
        if self._clip_mech is not None:
            square_sum = self._zero
            for grad in record_grad:
                square_sum = self._add(square_sum,
                                       self._reduce_sum(self._square_all(grad)))
            norm_grad = self._sqrt(square_sum)
            beta = self._add(beta,
                             self._cast(self._less(norm_grad, self._norm_bound),
                                        mstype.float32))

        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
                                                self._norm_bound)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
            record_grad = self.grad(self.network, weights)(record_datas[i],
                                                           record_labels[i],
                                                           sens)
            # calculate beta
            if self._clip_mech is not None:
                square_sum = self._zero
                for grad in record_grad:
                    square_sum = self._add(square_sum,
                                           self._reduce_sum(self._square_all(grad)))
                norm_grad = self._sqrt(square_sum)
                beta = self._add(beta,
                                 self._cast(self._less(norm_grad, self._norm_bound),
                                            mstype.float32))

            record_grad = self._clip_by_global_norm(record_grad,
                                                    GRADIENT_CLIP_TYPE,
                                                    self._norm_bound)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.Add()(total_loss, loss)
        loss = self._div(total_loss, self._micro_float)

        if self._noise_mech is not None:
            grad_noise_tuple = ()
            for grad_item in grads:
                grad_noise = self._noise_mech(grad_item)
                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
            grads = self._tuple_add(grads, grad_noise_tuple)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
                                    grads)
            # update mech parameters
            if self._noise_mech_param_updater is not None:
                multiplier = self._noise_mech_param_updater()
                loss = F.depend(loss, multiplier)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)

        if self._clip_mech is not None:
            beta = self._div(beta, self._micro_batches)
            next_norm_bound = self._clip_mech(beta, self._norm_bound)
            self._norm_bound = self._assign(self._norm_bound, next_norm_bound)
            loss = F.depend(loss, self._norm_bound)

        return F.depend(loss, self.optimizer(grads))
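The class docstring above illustrates the DPOptimizer path. The snippet below is a minimal sketch of the alternative noise_mech path; it assumes a MindArmour release that exposes NoiseMechanismsFactory next to DPModel (import paths and keyword names may differ between versions), and LeNet5/dataset_generator are placeholders as in the docstring example.

from mindspore import nn
import mindspore.dataset as ds
# Assumed import path; older MindArmour releases may use `mindarmour.diff_privacy` instead.
from mindarmour.privacy.diff_privacy import DPModel, NoiseMechanismsFactory

network = LeNet5()  # placeholder network, as in the docstring example
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
# Gaussian noise mechanism; keyword names follow the docstring example above.
noise_mech = NoiseMechanismsFactory().create('Gaussian',
                                             norm_bound=1.0,
                                             initial_noise_multiplier=0.01)
model = DPModel(micro_batches=2,
                norm_bound=1.0,
                noise_mech=noise_mech,
                clip_mech=None,
                network=network,
                loss_fn=loss,
                optimizer=opt,
                metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator, ['data', 'label'])  # placeholder generator
model.train(1, ms_ds, dataset_sink_mode=False)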

MindArmour focuses on the security and privacy of AI. It is dedicated to enhancing the trustworthiness of models and protecting users' data privacy. It mainly consists of three modules: the adversarial example robustness module, the Fuzz Testing module, and the privacy protection and evaluation module.