You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hook.py 5.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.nn as nn
  17. import mindspore.ops.operations as P
  18. from mindspore import context, Tensor, ParameterTuple
  19. from mindspore.common.initializer import TruncatedNormal
  20. from mindspore.nn import WithLossCell, Momentum
  21. from mindspore.ops import composite as C
  22. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  23. cell_hook_done = False
  24. var_hook_done = False
  25. cell_bprop_done = False
  26. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  27. """weight initial for conv layer"""
  28. weight = weight_variable()
  29. return nn.Conv2d(in_channels, out_channels,
  30. kernel_size=kernel_size, stride=stride, padding=padding,
  31. weight_init=weight, has_bias=False, pad_mode="valid")
  32. def fc_with_initialize(input_channels, out_channels):
  33. """weight initial for fc layer"""
  34. weight = weight_variable()
  35. bias = weight_variable()
  36. return nn.Dense(input_channels, out_channels, weight, bias)
  37. def weight_variable():
  38. """weight initial"""
  39. return TruncatedNormal(0.02)
  40. def cell_hook_function(cell_id, grad_input, grad_output):
  41. print(cell_id)
  42. global cell_hook_done
  43. cell_hook_done = True
  44. assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
  45. assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
  46. def var_hook_function(grad_out):
  47. print("grad:", grad_out)
  48. global var_hook_done
  49. var_hook_done = True
  50. assert (grad_out[0].asnumpy().shape == (32, 120))
  51. class Block(nn.Cell):
  52. def __init__(self):
  53. super(Block, self).__init__()
  54. self.relu = nn.ReLU()
  55. def construct(self, x):
  56. x = self.relu(x)
  57. return x
  58. def bprop(self, x, out, dout):
  59. global cell_bprop_done
  60. cell_bprop_done = True
  61. grad = out.asnumpy() * dout.asnumpy()
  62. grad = Tensor(grad)
  63. return (grad,)
  64. class LeNet5(nn.Cell):
  65. """
  66. Lenet network
  67. Args:
  68. num_class (int): Num classes. Default: 10.
  69. Returns:
  70. Tensor, output tensor
  71. Examples:
  72. >>> LeNet(num_class=10)
  73. """
  74. def __init__(self, num_class=10):
  75. super(LeNet5, self).__init__()
  76. self.num_class = num_class
  77. self.batch_size = 32
  78. self.conv1 = conv(1, 6, 5)
  79. self.conv2 = conv(6, 16, 5)
  80. self.conv2.register_backward_hook(cell_hook_function)
  81. self.block = Block()
  82. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  83. self.fc2 = fc_with_initialize(120, 84)
  84. self.fc3 = fc_with_initialize(84, self.num_class)
  85. self.relu = nn.ReLU()
  86. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  87. self.reshape = P.Reshape()
  88. self.hook = P.HookBackward(var_hook_function)
  89. def construct(self, x):
  90. x = self.conv1(x)
  91. x = self.relu(x)
  92. x = self.max_pool2d(x)
  93. x = self.conv2(x)
  94. x = self.block(x)
  95. x = self.max_pool2d(x)
  96. x = self.reshape(x, (self.batch_size, -1))
  97. x = self.fc1(x)
  98. x = self.hook(x)
  99. x = self.relu(x)
  100. x = self.fc2(x)
  101. x = self.relu(x)
  102. x = self.fc3(x)
  103. return x
  104. class GradWrap(nn.Cell):
  105. """ GradWrap definition """
  106. def __init__(self, network):
  107. super(GradWrap, self).__init__(auto_prefix=False)
  108. self.network = network
  109. self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
  110. def construct(self, x, label):
  111. weights = self.weights
  112. return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
  113. def test_hook():
  114. net = LeNet5()
  115. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
  116. criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
  117. net_with_criterion = WithLossCell(net, criterion)
  118. train_network = GradWrap(net_with_criterion)
  119. train_network.set_train()
  120. input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
  121. label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
  122. output = net(Tensor(input_data))
  123. loss_output = criterion(output, label)
  124. grads = train_network(input_data, label)
  125. success = optimizer(grads)
  126. assert cell_hook_done
  127. assert var_hook_done
  128. assert cell_bprop_done
  129. print(loss_output.asnumpy().shape)
  130. class MulAdd(nn.Cell):
  131. def __init__(self):
  132. super(MulAdd, self).__init__()
  133. def construct(self, x, y):
  134. return 2 * x + y
  135. def bprop(self, x, y, out, dout):
  136. assert (x == 1)
  137. assert (y == 2)
  138. assert (out == 4)
  139. assert (dout == 1)
  140. return 3 * dout, 2 * y
  141. def test_custom_bprop():
  142. mul_add = MulAdd()
  143. mul_add.bprop_debug = True
  144. assert C.grad_all(mul_add)(1, 2) == (3, 4)