Browse Source

!36 fix issue

Merge pull request !36 from zheng-huanhuan/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3f9c6901c0
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      mindarmour/diff_privacy/train/model.py

+ 2
- 0
mindarmour/diff_privacy/train/model.py View File

@@ -303,6 +303,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None): def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None):
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
@@ -450,6 +451,7 @@ class _TrainOneStepCell(Cell):
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None): def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None):
super(_TrainOneStepCell, self).__init__(auto_prefix=False) super(_TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer


Loading…
Cancel
Save