|
|
@@ -34,7 +34,7 @@ class MyTrainOneStepCell(nn.TrainOneStepCell): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network, optimizer, sens=1.0): |
|
|
|
super(MyTrainOneStepCell, self).__init__(network, optimizer,sens) |
|
|
|
super(MyTrainOneStepCell, self).__init__(network, optimizer, sens) |
|
|
|
self.grad = ops.composite.GradOperation(get_all=True, sens_param=False) |
|
|
|
|
|
|
|
def construct(self, *inputs): |
|
|
|