You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

GRUCell.cs 12 kB


  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Text;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.ArgsDefinition.Rnn;
  7. using Tensorflow.Common.Extensions;
  8. using Tensorflow.Common.Types;
  9. using Tensorflow.Keras.Saving;
  namespace Tensorflow.Keras.Layers.Rnn
  {
      /// <summary>
      /// Cell class for the GRU layer.
      /// </summary>
      /// <remarks>
      /// The kernel is created with shape <c>(input_dim, 3 * units)</c> and the recurrent kernel
      /// with shape <c>(units, 3 * units)</c>. Their column blocks are used, in order, for the
      /// update gate (z), the reset gate (r) and the candidate state (h). With
      /// <c>ResetAfter</c> the bias has shape <c>(2, 3 * units)</c>: row 0 is added to the input
      /// projection and row 1 to the recurrent projection; otherwise it is a single
      /// <c>(3 * units)</c> vector added to the input projection only.
      /// </remarks>
      public class GRUCell : DropoutRNNCellMixin
      {
          GRUCellArgs _args;
          IVariableV1 _kernel;
          IVariableV1 _recurrent_kernel;
          IInitializer _bias_initializer;
          IVariableV1 _bias;
          INestStructure<long> _state_size;
          INestStructure<long> _output_size;
          int Units;

          public override INestStructure<long> StateSize => _state_size;
          public override INestStructure<long> OutputSize => _output_size;
          public override bool SupportOptionalArgs => false;

          /// <summary>
          /// Creates the cell and normalizes its configuration.
          /// </summary>
          /// <param name="args">
          /// Cell options. <c>Dropout</c> and <c>RecurrentDropout</c> are clamped to [0, 1] and
          /// <c>Implementation</c> may be forced to 1; these changes are written back to this object.
          /// </param>
          /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not positive.</exception>
          public GRUCell(GRUCellArgs args) : base(args)
          {
              _args = args;
              if (_args.Units <= 0)
              {
                  throw new ValueError(
                      $"units must be a positive integer, got {args.Units}");
              }
              _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
              _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
              // Recurrent dropout is only wired into the per-gate (implementation 1) code path.
              if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
              {
                  Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
                      "Using `implementation=1`.");
                  _args.Implementation = 1;
              }
              Units = _args.Units;
              // build() reads this field when creating the bias weight.
              _bias_initializer = _args.BiasInitializer;
              _state_size = new NestList<long>(Units);
              _output_size = new NestNode<long>(Units);
          }

          /// <summary>
          /// Creates the kernel, recurrent kernel and (when <c>UseBias</c>) bias weights.
          /// </summary>
          /// <param name="input_shape">Input shape; its last dimension is the input feature size.</param>
          public override void build(KerasShapesWrapper input_shape)
          {
              var single_shape = input_shape.ToSingleShape();
              var input_dim = single_shape[-1];
              _kernel = add_weight("kernel", (input_dim, _args.Units * 3),
                  initializer: _args.KernelInitializer
                  );
              _recurrent_kernel = add_weight("recurrent_kernel", (Units, Units * 3),
                  initializer: _args.RecurrentInitializer
                  );
              if (_args.UseBias)
              {
                  Shape bias_shape;
                  if (!_args.ResetAfter)
                  {
                      bias_shape = new Shape(3 * Units);
                  }
                  else
                  {
                      // Separate input and recurrent biases, stacked as two rows.
                      bias_shape = (2, 3 * Units);
                  }
                  _bias = add_weight("bias", bias_shape,
                      initializer: _bias_initializer
                      );
              }
              built = true;
          }

          /// <summary>
          /// Runs one GRU step.
          /// </summary>
          /// <param name="inputs">Input for the current step.</param>
          /// <param name="states">
          /// Previous hidden state: a single tensor, or a nested structure whose first element is the state.
          /// </param>
          /// <param name="training">
          /// Passed to the dropout-mask helpers; <c>null</c> is treated as <c>false</c>.
          /// </param>
          /// <param name="optional_args">Not used by this cell.</param>
          /// <returns>
          /// The new hidden state <c>h</c> as the output, followed by the new state
          /// (wrapped in a list when <paramref name="states"/> was nested).
          /// </returns>
          protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
          {
              var h_tm1 = states.IsNested() ? states[0] : states.Single();
              bool is_training = training ?? false;
              var dp_mask = get_dropout_mask_for_cell(inputs, is_training, count: 3);
              var rec_dp_mask = get_recurrent_dropout_mask_for_cell(h_tm1, is_training, count: 3);

              var kernel = _kernel.AsTensor();
              var recurrent_kernel = _recurrent_kernel.AsTensor();

              // Bias values are read as tensors of `_bias` itself (not copied into new variables)
              // so that gradients propagate back to the bias weight.
              Tensor input_bias = null;
              Tensor recurrent_bias = null;
              if (_args.UseBias)
              {
                  var bias = _bias.AsTensor();
                  if (!_args.ResetAfter)
                  {
                      input_bias = bias;
                  }
                  else
                  {
                      var bias_rows = tf.unstack(bias);
                      input_bias = bias_rows[0];
                      recurrent_bias = bias_rows[1];
                  }
              }

              Tensor z;
              Tensor hh;
              if (_args.Implementation == 1)
              {
                  // Per-gate matmuls, so each gate can get its own dropout mask.
                  Tensor inputs_z;
                  Tensor inputs_r;
                  Tensor inputs_h;
                  if (0f < _args.Dropout && _args.Dropout < 1f)
                  {
                      inputs_z = inputs * dp_mask[0];
                      inputs_r = inputs * dp_mask[1];
                      inputs_h = inputs * dp_mask[2];
                  }
                  else
                  {
                      inputs_z = inputs.Single();
                      inputs_r = inputs.Single();
                      inputs_h = inputs.Single();
                  }

                  var x_z = math_ops.matmul(inputs_z, SliceColumns(kernel, 0, Units));
                  var x_r = math_ops.matmul(inputs_r, SliceColumns(kernel, Units, Units));
                  var x_h = math_ops.matmul(inputs_h, SliceColumns(kernel, 2 * Units, Units));
                  if (_args.UseBias)
                  {
                      x_z += input_bias[$":{Units}"];
                      x_r += input_bias[$"{Units}:{Units * 2}"];
                      x_h += input_bias[$"{Units * 2}:"];
                  }

                  Tensor h_tm1_z;
                  Tensor h_tm1_r;
                  Tensor h_tm1_h;
                  if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
                  {
                      h_tm1_z = h_tm1 * rec_dp_mask[0];
                      h_tm1_r = h_tm1 * rec_dp_mask[1];
                      h_tm1_h = h_tm1 * rec_dp_mask[2];
                  }
                  else
                  {
                      h_tm1_z = h_tm1;
                      h_tm1_r = h_tm1;
                      h_tm1_h = h_tm1;
                  }

                  var recurrent_z = math_ops.matmul(h_tm1_z, SliceColumns(recurrent_kernel, 0, Units));
                  var recurrent_r = math_ops.matmul(h_tm1_r, SliceColumns(recurrent_kernel, Units, Units));
                  if (_args.ResetAfter && _args.UseBias)
                  {
                      recurrent_z += recurrent_bias[$":{Units}"];
                      recurrent_r += recurrent_bias[$"{Units}:{Units * 2}"];
                  }
                  z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
                  var r = _args.RecurrentActivation.Apply(x_r + recurrent_r);

                  Tensor recurrent_h;
                  if (_args.ResetAfter)
                  {
                      // Reset gate applied after the recurrent matmul.
                      recurrent_h = math_ops.matmul(h_tm1_h, SliceColumns(recurrent_kernel, 2 * Units, Units));
                      if (_args.UseBias)
                      {
                          recurrent_h += recurrent_bias[$"{Units * 2}:"];
                      }
                      recurrent_h *= r;
                  }
                  else
                  {
                      // Reset gate applied to the previous state before the recurrent matmul.
                      recurrent_h = math_ops.matmul(r * h_tm1_h, SliceColumns(recurrent_kernel, 2 * Units, Units));
                  }
                  hh = _args.Activation.Apply(x_h + recurrent_h);
              }
              else
              {
                  // Fused path: project inputs (and state) through all gates at once, then split.
                  if (0f < _args.Dropout && _args.Dropout < 1f)
                  {
                      inputs = inputs * dp_mask[0];
                  }
                  var matrix_x = math_ops.matmul(inputs, kernel);
                  if (_args.UseBias)
                  {
                      matrix_x += input_bias;
                  }
                  var x_parts = tf.split(matrix_x, 3, axis: -1);
                  var x_z = x_parts[0];
                  var x_r = x_parts[1];
                  var x_h = x_parts[2];

                  Tensor matrix_inner;
                  if (_args.ResetAfter)
                  {
                      matrix_inner = math_ops.matmul(h_tm1, recurrent_kernel);
                      if (_args.UseBias)
                      {
                          matrix_inner += recurrent_bias;
                      }
                  }
                  else
                  {
                      // Only the z and r columns here; the candidate term needs r first.
                      matrix_inner = math_ops.matmul(h_tm1, SliceColumns(recurrent_kernel, 0, 2 * Units));
                  }
                  var inner_parts = tf.split(matrix_inner, new int[] { Units, Units, -1 }, axis: -1);
                  var recurrent_z = inner_parts[0];
                  var recurrent_r = inner_parts[1];
                  var recurrent_h = inner_parts[2];
                  z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
                  var r = _args.RecurrentActivation.Apply(x_r + recurrent_r);
                  if (_args.ResetAfter)
                  {
                      recurrent_h = r * recurrent_h;
                  }
                  else
                  {
                      recurrent_h = math_ops.matmul(r * h_tm1, SliceColumns(recurrent_kernel, 2 * Units, Units));
                  }
                  hh = _args.Activation.Apply(x_h + recurrent_h);
              }

              // Blend previous state and candidate state with the update gate.
              var h = z * h_tm1 + (1 - z) * hh;
              if (states.IsNested())
              {
                  var new_state = new NestList<Tensor>(h);
                  return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new_state }).ToTensors();
              }
              return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestNode<Tensor>(h) }).ToTensors();
          }

          /// <summary>
          /// Returns columns <c>[start, start + count)</c> of a rank-2 tensor, keeping every row.
          /// </summary>
          private static Tensor SliceColumns(Tensor matrix, int start, int count)
          {
              int rows = (int)matrix.shape[0];
              return tf.slice(matrix, new[] { 0, start }, new[] { rows, count });
          }
      }
  }