diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py index ea22d14..17c4927 100644 --- a/mindarmour/diff_privacy/train/model.py +++ b/mindarmour/diff_privacy/train/model.py @@ -303,6 +303,7 @@ class _TrainOneStepWithLossScaleCell(Cell): def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None): super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network + self.network.set_grad() self.network.add_flags(defer_inline=True) self.weights = ParameterTuple(network.trainable_params()) self.optimizer = optimizer @@ -450,6 +451,7 @@ class _TrainOneStepCell(Cell): def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None): super(_TrainOneStepCell, self).__init__(auto_prefix=False) self.network = network + self.network.set_grad() self.network.add_flags(defer_inline=True) self.weights = optimizer.parameters self.optimizer = optimizer