|
@@ -303,6 +303,7 @@ class _TrainOneStepWithLossScaleCell(Cell): |
|
|
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None): |
|
|
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None): |
|
|
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) |
|
|
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) |
|
|
self.network = network |
|
|
self.network = network |
|
|
|
|
|
self.network.set_grad() |
|
|
self.network.add_flags(defer_inline=True) |
|
|
self.network.add_flags(defer_inline=True) |
|
|
self.weights = ParameterTuple(network.trainable_params()) |
|
|
self.weights = ParameterTuple(network.trainable_params()) |
|
|
self.optimizer = optimizer |
|
|
self.optimizer = optimizer |
|
@@ -450,6 +451,7 @@ class _TrainOneStepCell(Cell): |
|
|
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None): |
|
|
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None): |
|
|
super(_TrainOneStepCell, self).__init__(auto_prefix=False) |
|
|
super(_TrainOneStepCell, self).__init__(auto_prefix=False) |
|
|
self.network = network |
|
|
self.network = network |
|
|
|
|
|
self.network.set_grad() |
|
|
self.network.add_flags(defer_inline=True) |
|
|
self.network.add_flags(defer_inline=True) |
|
|
self.weights = optimizer.parameters |
|
|
self.weights = optimizer.parameters |
|
|
self.optimizer = optimizer |
|
|
self.optimizer = optimizer |
|
|