
BatchNormalization.cs

/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
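    /// <summary>
    /// Keras batch normalization layer: normalizes activations over the batch,
    /// then applies a learned scale (gamma) and offset (beta). Running statistics
    /// (moving_mean / moving_variance) are tracked for use at inference time.
    /// </summary>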
    public class BatchNormalization : Layer
    {
        BatchNormalizationArgs args;

        float momentum => args.Momentum;
        float epsilon => args.Epsilon;
        bool center => args.Center;
        bool scale => args.Scale;
        bool renorm => args.Renorm;
        bool fused;
        int[] axis;
        string _data_format;
        Shape kernel_size;
        IInitializer beta_initializer => args.BetaInitializer;
        IInitializer gamma_initializer => args.GammaInitializer;
        IInitializer moving_mean_initializer => args.MovingMeanInitializer;
        IInitializer moving_variance_initializer => args.MovingVarianceInitializer;
        IRegularizer gamma_regularizer => args.GammaRegularizer;

        IVariableV1 gamma;
        IVariableV1 beta;
        IVariableV1 moving_mean;
        IVariableV1 moving_variance;

        public BatchNormalization(BatchNormalizationArgs args) : base(args)
        {
            this.args = args;
            axis = args.Axis.dims.Select(x => (int)x).ToArray();
        }
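        /// <summary>
        /// Creates the layer's weights. Negative axis values are normalized to
        /// positive indices, the fused kernel is selected for 4-D inputs, and the
        /// gamma/beta and moving-statistics variables are created with one element
        /// per dimension listed in <c>axis</c>.
        /// </summary>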
        public override void build(KerasShapesWrapper input_shape)
        {
            var single_shape = input_shape.ToSingleShape();
            var ndims = single_shape.ndim;

            // Normalize negative axis values (e.g. -1) to positive indices.
            foreach (var (idx, x) in enumerate(axis))
                if (x < 0)
                    args.Axis.dims[idx] = axis[idx] = ndims + x;

            // The fused kernel only supports 4-D inputs in NCHW or NHWC layout.
            fused = ndims == 4;
            if (fused)
            {
                if (Enumerable.SequenceEqual(axis, new int[] { 1 }))
                    _data_format = "NCHW";
                else if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
                    _data_format = "NHWC";
                else
                    throw new ValueError("Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]");
            }

            var axis_to_dim = new Dictionary<int, int>();
            foreach (var x in axis)
                axis_to_dim[x] = (int)single_shape[x];

            inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim);
            var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
            var param_shape = inputSpec.AllAxisDim;

            if (scale)
                gamma = add_weight("gamma",
                    param_shape,
                    dtype: param_dtype,
                    initializer: gamma_initializer,
                    trainable: true);
            else
                throw new NotImplementedException("add_weight gamma");

            if (center)
                beta = add_weight("beta",
                    param_shape,
                    dtype: param_dtype,
                    initializer: beta_initializer,
                    trainable: true);
            else
                throw new NotImplementedException("add_weight beta");

            moving_mean = add_weight("moving_mean",
                param_shape,
                dtype: param_dtype,
                initializer: moving_mean_initializer,
                synchronization: VariableSynchronization.OnRead,
                aggregation: VariableAggregation.Mean,
                trainable: false);

            moving_variance = add_weight("moving_variance",
                shape: param_shape,
                dtype: param_dtype,
                initializer: moving_variance_initializer,
                synchronization: VariableSynchronization.OnRead,
                aggregation: VariableAggregation.Mean,
                trainable: false);

            if (renorm)
                throw new NotImplementedException("build when renorm is true");

            built = true;
            _buildInputShape = input_shape;
        }
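        // Batch normalization is element-wise, so the output shape equals the input shape.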
        public override Shape ComputeOutputShape(Shape input_shape)
        {
            return input_shape;
        }
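        /// <summary>
        /// Computes the batch mean and variance over <c>reduction_axes</c>.
        /// The zero-size-input path is not implemented.
        /// </summary>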
        (Tensor, Tensor) _moments(Tensors inputs, int[] reduction_axes, bool keep_dims)
        {
            var (mean, variance) = _calculate_mean_and_var(inputs, reduction_axes, keep_dims);
            if (_support_zero_size_input())
                throw new NotImplementedException("");
            return (mean, variance);
        }

        (Tensor, Tensor) _calculate_mean_and_var(Tensors inputs, int[] reduction_axes, bool keep_dims)
        {
            return nn_impl.moments(inputs, reduction_axes, keep_dims: keep_dims);
        }

        bool _support_zero_size_input()
        {
            return false;
        }
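        /// <summary>
        /// Forward pass. For 4-D inputs the fused kernel is used; otherwise the
        /// layer normalizes with batch statistics when training and with the
        /// stored moving statistics when not.
        /// </summary>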
        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            Tensor outputs = null;
            var training_tensor = training == null
                ? tf.placeholder(tf.@bool, Shape.Scalar)
                : tf.logical_and(training.Value, Trainable);

            if (fused)
            {
                // var training = tf.convert_to_tensor(training);
                outputs = _fused_batch_norm(inputs, training: training_tensor);
                return outputs;
            }

            var inputs_dtype = inputs.dtype.as_base_dtype();
            var input_shape = inputs.shape;
            var ndims = len(input_shape);
            var reduction_axes = range(ndims).Where(x => !axis.Contains(x)).ToArray();

            // Broadcasting is only necessary for single-axis batch norm where the
            // axis is not the last dimension.
            var broadcast_shape = range(ndims).Select(x => 1).ToArray();
            broadcast_shape[axis[0]] = (int)input_shape.dims[axis[0]];

            var (scale, offset) = (gamma, beta);
            var training_value = tf_utils.constant_value(training_tensor);
            Tensor mean;
            Tensor variance;
            if (training_value.HasValue && training_value.Value == false)
            {
                // Inference with a statically known training flag: use the moving statistics.
                (mean, variance) = (moving_mean.AsTensor(), moving_variance.AsTensor());
            }
            else
            {
                var keep_dims = len(axis) > 1;
                (mean, variance) = _moments(inputs, reduction_axes, keep_dims: keep_dims);
                mean = tf_utils.smart_cond(training_tensor,
                    () => new[] { mean },
                    () => new[] { ops.convert_to_tensor(moving_mean) }).FirstOrDefault();
                variance = tf_utils.smart_cond(training_tensor,
                    () => new[] { variance },
                    () => new[] { ops.convert_to_tensor(moving_variance) }).FirstOrDefault();
                // Note: the moving averages are not updated on this (non-fused) path.
                var (new_mean, new_variance) = (mean, variance);
            }

            mean = math_ops.cast(mean, inputs.dtype);
            variance = math_ops.cast(variance, inputs.dtype);
            var offset_tensor = math_ops.cast(offset, inputs.dtype);
            var scale_tensor = math_ops.cast(scale, inputs.dtype);
            outputs = nn_impl.batch_normalization(inputs, mean, variance,
                offset_tensor, scale_tensor, epsilon);

            // If some components of the shape got lost due to adjustments, fix that.
            outputs.shape = input_shape;
            return outputs;
        }
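        /// <summary>
        /// Applies batch normalization with tf.nn.fused_batch_norm, choosing the
        /// training or inference kernel via a smart cond on the training tensor,
        /// and updates the moving statistics when training.
        /// </summary>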
        private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
        {
            Shape input_batch_size = null;
            var use_fused_avg_updates = true;
            float exponential_avg_factor = 0;
            if (use_fused_avg_updates)
                exponential_avg_factor = 1.0f - momentum;

            Func<Tensor[]> _fused_batch_norm_training = () =>
            {
                return tf.nn.fused_batch_norm(
                    inputs,
                    gamma.AsTensor(),
                    beta.AsTensor(),
                    mean: moving_mean.AsTensor(),
                    variance: moving_variance.AsTensor(),
                    epsilon: epsilon,
                    is_training: true,
                    data_format: _data_format,
                    exponential_avg_factor: exponential_avg_factor);
            };

            Func<Tensor[]> _fused_batch_norm_inference = () =>
            {
                return tf.nn.fused_batch_norm(
                    inputs,
                    gamma.AsTensor(),
                    beta.AsTensor(),
                    mean: moving_mean.AsTensor(),
                    variance: moving_variance.AsTensor(),
                    epsilon: epsilon,
                    is_training: false,
                    data_format: _data_format);
            };

            if (use_fused_avg_updates && input_batch_size != null)
                throw new NotImplementedException("");

            var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
            var (output, mean, variance) = (results[0], results[1], results[2]);
            var training_value = tf_utils.constant_value(training);

            if (!training_value.HasValue || training_value.Value)
            {
                Tensor momentum_tensor = null;
                if (!use_fused_avg_updates)
                {
                    if (!training_value.HasValue)
                        momentum_tensor = tf_utils.smart_cond(training,
                            () => new float[] { momentum },
                            () => new float[] { 1.0f })[0];
                    else
                        momentum_tensor = ops.convert_to_tensor(momentum);
                }

                // With fused average updates the kernel has already folded the new
                // statistics into its outputs, so they are assigned directly;
                // otherwise an explicit moving-average update is applied.
                if (use_fused_avg_updates)
                    _assign_new_value(moving_mean, mean);
                else
                    _assign_moving_average(moving_mean, mean, momentum_tensor);

                if (use_fused_avg_updates)
                    _assign_new_value(moving_variance, variance);
                else
                    _assign_moving_average(moving_variance, variance, momentum_tensor);

                // var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor);
                // var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor);
                // add_update(new Tensor[] { mean_update }, inputs: true);
                // add_update(new Tensor[] { variance_update }, inputs: true);
            }

            return output;
        }
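        // Overwrites the variable with the new value (used when the fused kernel
        // has already applied the exponential-average update).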
        void _assign_new_value(IVariableV1 variable, Tensor value)
        {
            tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
            {
                // var cm = ops.colocate_with(variable);
                variable.assign_lazy_load(value, name: scope);
            });
        }
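        // Exponential moving average update:
        //   variable -= (variable - value) * (1 - momentum)
        // which is equivalent to variable = momentum * variable + (1 - momentum) * value.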
        void _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
        {
            tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
            {
                // var cm = ops.colocate_with(variable);
                var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
                var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay;
                variable.assign_sub_lazy_load(update_delta, name: scope);
            });
        }
    }
}