You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

var_init.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """weight initial"""
  16. import math
  17. import numpy as np
  18. from mindspore.common import initializer as init
  19. import mindspore.nn as nn
  20. from mindspore import Tensor
  21. def calculate_gain(nonlinearity, param=None):
  22. r"""Return the recommended gain value for the given nonlinearity function.
  23. The values are as follows:
  24. ================= ====================================================
  25. nonlinearity gain
  26. ================= ====================================================
  27. Linear / Identity :math:`1`
  28. Conv{1,2,3}D :math:`1`
  29. Sigmoid :math:`1`
  30. Tanh :math:`\frac{5}{3}`
  31. ReLU :math:`\sqrt{2}`
  32. Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
  33. ================= ====================================================
  34. Args:
  35. nonlinearity: the non-linear function (`nn.functional` name)
  36. param: optional parameter for the non-linear function
  37. """
  38. linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
  39. gain = 0
  40. if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
  41. gain = 1
  42. elif nonlinearity == 'tanh':
  43. gain = 5.0 / 3
  44. elif nonlinearity == 'relu':
  45. gain = math.sqrt(2.0)
  46. elif nonlinearity == 'leaky_relu':
  47. if param is None:
  48. negative_slope = 0.01
  49. elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
  50. # True/False are instances of int, hence check above
  51. negative_slope = param
  52. else:
  53. raise ValueError("negative_slope {} not a valid number".format(param))
  54. gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
  55. else:
  56. raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
  57. return gain
  58. def _calculate_correct_fan(array, mode):
  59. mode = mode.lower()
  60. valid_modes = ['fan_in', 'fan_out']
  61. if mode not in valid_modes:
  62. raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
  63. fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
  64. return fan_in if mode == 'fan_in' else fan_out
  65. def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
  66. r"""Fills the input `Tensor` with values according to the method
  67. described in `Delving deep into rectifiers: Surpassing human-level
  68. performance on ImageNet classification` - He, K. et al. (2015), using a
  69. uniform distribution. The resulting tensor will have values sampled from
  70. :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
  71. .. math::
  72. \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
  73. Also known as He initialization.
  74. Args:
  75. array: an n-dimensional `tensor`
  76. a: the negative slope of the rectifier used after this layer (only
  77. used with ``'leaky_relu'``)
  78. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  79. preserves the magnitude of the variance of the weights in the
  80. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  81. backwards pass.
  82. nonlinearity: the non-linear function (`nn.functional` name),
  83. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  84. """
  85. fan = _calculate_correct_fan(array, mode)
  86. gain = calculate_gain(nonlinearity, a)
  87. std = gain / math.sqrt(fan)
  88. bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  89. return np.random.uniform(-bound, bound, array.shape)
  90. def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
  91. r"""Fills the input `Tensor` with values according to the method
  92. described in `Delving deep into rectifiers: Surpassing human-level
  93. performance on ImageNet classification` - He, K. et al. (2015), using a
  94. normal distribution. The resulting tensor will have values sampled from
  95. :math:`\mathcal{N}(0, \text{std}^2)` where
  96. .. math::
  97. \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
  98. Also known as He initialization.
  99. Args:
  100. array: an n-dimensional `tensor`
  101. a: the negative slope of the rectifier used after this layer (only
  102. used with ``'leaky_relu'``)
  103. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  104. preserves the magnitude of the variance of the weights in the
  105. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  106. backwards pass.
  107. nonlinearity: the non-linear function (`nn.functional` name),
  108. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  109. """
  110. fan = _calculate_correct_fan(array, mode)
  111. gain = calculate_gain(nonlinearity, a)
  112. std = gain / math.sqrt(fan)
  113. return np.random.normal(0, std, array.shape)
  114. def _calculate_fan_in_and_fan_out(array):
  115. """calculate the fan_in and fan_out for input array"""
  116. dimensions = len(array.shape)
  117. if dimensions < 2:
  118. raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
  119. num_input_fmaps = array.shape[1]
  120. num_output_fmaps = array.shape[0]
  121. receptive_field_size = 1
  122. if dimensions > 2:
  123. receptive_field_size = array[0][0].size
  124. fan_in = num_input_fmaps * receptive_field_size
  125. fan_out = num_output_fmaps * receptive_field_size
  126. return fan_in, fan_out
  127. def assignment(arr, num):
  128. """Assign the value of num to arr"""
  129. if arr.shape == ():
  130. arr = arr.reshape((1))
  131. arr[:] = num
  132. arr = arr.reshape(())
  133. else:
  134. if isinstance(num, np.ndarray):
  135. arr[:] = num[:]
  136. else:
  137. arr[:] = num
  138. return arr
  139. class KaimingUniform(init.Initializer):
  140. def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
  141. super(KaimingUniform, self).__init__()
  142. self.a = a
  143. self.mode = mode
  144. self.nonlinearity = nonlinearity
  145. def _initialize(self, arr):
  146. tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
  147. assignment(arr, tmp)
  148. class KaimingNormal(init.Initializer):
  149. def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
  150. super(KaimingNormal, self).__init__()
  151. self.a = a
  152. self.mode = mode
  153. self.nonlinearity = nonlinearity
  154. def _initialize(self, arr):
  155. tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity)
  156. assignment(arr, tmp)
  157. def default_recurisive_init(custom_cell):
  158. """weight init for conv2d and dense"""
  159. for _, cell in custom_cell.cells_and_names():
  160. if isinstance(cell, nn.Conv2d):
  161. cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
  162. cell.weight.default_input.shape(),
  163. cell.weight.default_input.dtype())
  164. if cell.bias is not None:
  165. fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
  166. bound = 1 / math.sqrt(fan_in)
  167. cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
  168. cell.bias.default_input.shape()),
  169. cell.bias.default_input.dtype())
  170. elif isinstance(cell, nn.Dense):
  171. cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
  172. cell.weight.default_input.shape(),
  173. cell.weight.default_input.dtype())
  174. if cell.bias is not None:
  175. fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
  176. bound = 1 / math.sqrt(fan_in)
  177. cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
  178. cell.bias.default_input.shape()),
  179. cell.bias.default_input.dtype())
  180. elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
  181. pass