You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

BatchNormalization.cs 9.1 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections.Generic;
  15. using System.Linq;
  16. using Tensorflow.Keras.ArgsDefinition;
  17. using Tensorflow.Keras.Engine;
  18. using Tensorflow.Keras.Utils;
  19. using static Tensorflow.Binding;
  20. namespace Tensorflow.Keras.Layers
  21. {
  22. public class BatchNormalization : Layer
  23. {
  24. BatchNormalizationArgs args;
  25. float momentum => args.Momentum;
  26. float epsilon => args.Epsilon;
  27. bool center => args.Center;
  28. bool scale => args.Scale;
  29. bool renorm => args.Renorm;
  30. bool fused;
  31. int[] axis;
  32. string _data_format;
  33. IInitializer beta_initializer => args.BetaInitializer;
  34. IInitializer gamma_initializer => args.GammaInitializer;
  35. IInitializer moving_mean_initializer => args.MovingMeanInitializer;
  36. IInitializer moving_variance_initializer => args.MovingVarianceInitializer;
  37. IRegularizer gamma_regularizer => args.GammaRegularizer;
  38. IVariableV1 gamma;
  39. IVariableV1 beta;
  40. IVariableV1 moving_mean;
  41. IVariableV1 moving_variance;
  42. public BatchNormalization(BatchNormalizationArgs args) : base(args)
  43. {
  44. this.args = args;
  45. axis = args.Axis.dims;
  46. }
  47. protected override void build(TensorShape input_shape)
  48. {
  49. var ndims = input_shape.ndim;
  50. foreach (var (idx, x) in enumerate(axis))
  51. if (x < 0)
  52. axis[idx] = ndims + x;
  53. fused = ndims == 4;
  54. if (fused)
  55. {
  56. if (Enumerable.SequenceEqual(axis, new int[] { 1 }))
  57. _data_format = "NCHW";
  58. else if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
  59. _data_format = "NHWC";
  60. else
  61. throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]");
  62. }
  63. var axis_to_dim = new Dictionary<int, int>();
  64. foreach (var x in axis)
  65. axis_to_dim[x] = input_shape[x];
  66. inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim);
  67. var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
  68. var param_shape = inputSpec.AllAxisDim;
  69. if (scale)
  70. gamma = add_weight("gamma",
  71. param_shape,
  72. dtype: param_dtype,
  73. initializer: gamma_initializer,
  74. trainable: true);
  75. else
  76. throw new NotImplementedException("add_weight gamma");
  77. if (center)
  78. beta = add_weight("beta",
  79. param_shape,
  80. dtype: param_dtype,
  81. initializer: beta_initializer,
  82. trainable: true);
  83. else
  84. throw new NotImplementedException("add_weight beta");
  85. moving_mean = add_weight("moving_mean",
  86. param_shape,
  87. dtype: param_dtype,
  88. initializer: moving_mean_initializer,
  89. synchronization: VariableSynchronization.OnRead,
  90. aggregation: VariableAggregation.Mean,
  91. trainable: false);
  92. moving_variance = add_weight("moving_variance",
  93. shape: param_shape,
  94. dtype: param_dtype,
  95. initializer: moving_variance_initializer,
  96. synchronization: VariableSynchronization.OnRead,
  97. aggregation: VariableAggregation.Mean,
  98. trainable: false);
  99. if (renorm)
  100. throw new NotImplementedException("build when renorm is true");
  101. built = true;
  102. }
  103. protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
  104. {
  105. Tensor outputs = null;
  106. var training_tensor = tf.logical_and(training, Trainable);
  107. if (fused)
  108. {
  109. // var training = tf.convert_to_tensor(training);
  110. outputs = _fused_batch_norm(inputs, training: training_tensor);
  111. return outputs;
  112. }
  113. throw new NotImplementedException("BatchNormalization call");
  114. }
  115. private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
  116. {
  117. TensorShape input_batch_size = null;
  118. var use_fused_avg_updates = true;
  119. float exponential_avg_factor = 0;
  120. if (use_fused_avg_updates)
  121. exponential_avg_factor = 1.0f - momentum;
  122. var beta = this.beta;
  123. var gamma = this.gamma;
  124. Func<Tensor[]> _fused_batch_norm_training = () =>
  125. {
  126. return tf.nn.fused_batch_norm(
  127. inputs,
  128. gamma,
  129. beta,
  130. mean: moving_mean,
  131. variance: moving_variance,
  132. epsilon: epsilon, is_training: true,
  133. data_format: _data_format,
  134. exponential_avg_factor: exponential_avg_factor);
  135. };
  136. Func<Tensor[]> _fused_batch_norm_inference = () =>
  137. {
  138. return tf.nn.fused_batch_norm(
  139. inputs,
  140. gamma,
  141. beta,
  142. mean: moving_mean,
  143. variance: moving_variance,
  144. epsilon: epsilon,
  145. is_training: false,
  146. data_format: _data_format);
  147. };
  148. if (use_fused_avg_updates && input_batch_size != null)
  149. throw new NotImplementedException("");
  150. var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
  151. var (output, mean, variance) = (results[0], results[1], results[2]);
  152. var training_value = tf_utils.constant_value(training);
  153. if (!training_value.HasValue || (training_value.HasValue && training_value.Value))
  154. {
  155. Tensor momentum_tensor = null;
  156. if (!use_fused_avg_updates)
  157. {
  158. if (training_value == null)
  159. momentum_tensor = tf_utils.smart_cond(training,
  160. () => new float[] { momentum },
  161. () => new float[] { 1.0f })[0];
  162. else
  163. momentum_tensor = ops.convert_to_tensor(momentum);
  164. }
  165. if (use_fused_avg_updates)
  166. _assign_new_value(moving_mean, mean);
  167. else
  168. _assign_moving_average(moving_variance, variance, momentum_tensor);
  169. if (use_fused_avg_updates)
  170. _assign_new_value(moving_variance, mean);
  171. else
  172. _assign_moving_average(moving_variance, variance, momentum_tensor);
  173. // var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor);
  174. // var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor);
  175. // add_update(new Tensor[] { mean_update }, inputs: true);
  176. // add_update(new Tensor[] { variance_update }, inputs: true);
  177. }
  178. return output;
  179. }
  180. Tensor _assign_new_value(IVariableV1 variable, Tensor value)
  181. {
  182. return tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
  183. {
  184. // var cm = ops.colocate_with(variable);
  185. return state_ops.assign_sub(variable, value, name: scope);
  186. });
  187. }
  188. Tensor _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
  189. {
  190. return tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
  191. {
  192. // var cm = ops.colocate_with(variable);
  193. var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
  194. var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay;
  195. return state_ops.assign_sub(variable, update_delta, name: scope);
  196. });
  197. }
  198. }
  199. }