
LSTMCell.cs

using Serilog.Core;
using System.Diagnostics;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
    /// <summary>
    /// Cell class for the LSTM layer.
    /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    /// for details about the usage of the RNN API.
    /// This class processes one step within the whole time sequence input, whereas
    /// `tf.keras.layer.LSTM` processes the whole sequence.
    /// </summary>
    public class LSTMCell : DropoutRNNCellMixin
    {
        LSTMCellArgs _args;
        IVariableV1 _kernel;
        IVariableV1 _recurrent_kernel;
        IInitializer _bias_initializer;
        IVariableV1 _bias;
        GeneralizedTensorShape _state_size;
        GeneralizedTensorShape _output_size;

        public override GeneralizedTensorShape StateSize => _state_size;
        public override GeneralizedTensorShape OutputSize => _output_size;
        public override bool IsTFRnnCell => true;
        public override bool SupportOptionalArgs => false;

        public LSTMCell(LSTMCellArgs args)
            : base(args)
        {
            _args = args;
            if (args.Units <= 0)
            {
                throw new ValueError(
                    $"units must be a positive integer, got {args.Units}");
            }
            // Clamp the dropout rates to the valid [0, 1] range.
            _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
            _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
            if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
            {
                Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
                    "Using `implementation=1`.");
                _args.Implementation = 1;
            }
            // The state is a pair of tensors of size `units`: the memory state h and the carry state c.
            _state_size = new GeneralizedTensorShape(_args.Units, 2);
            _output_size = new GeneralizedTensorShape(_args.Units);
        }
        public override void build(KerasShapesWrapper input_shape)
        {
            var single_shape = input_shape.ToSingleShape();
            var input_dim = single_shape[-1];
            // Both kernels pack the four LSTM gates (i, f, c, o) along the last axis.
            _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
                initializer: _args.KernelInitializer
            );
            _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
                initializer: _args.RecurrentInitializer
            );
            if (_args.UseBias)
            {
                if (_args.UnitForgetBias)
                {
                    // Unit forget bias: ones for the forget-gate chunk, the configured
                    // initializer for the remaining gate chunks.
                    Tensor bias_initializer()
                    {
                        return keras.backend.concatenate(
                            new Tensors(
                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
                                tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units * 2)))), axis: 0);
                    }
                }
                else
                {
                    _bias_initializer = _args.BiasInitializer;
                }
                _bias = add_weight("bias", (_args.Units * 4),
                    initializer: _args.BiasInitializer);
            }
            built = true;
        }
        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            var h_tm1 = states[0]; // previous memory state
            var c_tm1 = states[1]; // previous carry state
            var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
                h_tm1, training.Value, count: 4);

            Tensor c;
            Tensor o;
            if (_args.Implementation == 1)
            {
                // Implementation 1: split the kernel into four per-gate blocks and compute
                // the gate pre-activations with separate, smaller matmuls.
                Tensor inputs_i;
                Tensor inputs_f;
                Tensor inputs_c;
                Tensor inputs_o;
                if (0f < _args.Dropout && _args.Dropout < 1f)
                {
                    inputs_i = inputs * dp_mask[0];
                    inputs_f = inputs * dp_mask[1];
                    inputs_c = inputs * dp_mask[2];
                    inputs_o = inputs * dp_mask[3];
                }
                else
                {
                    inputs_i = inputs;
                    inputs_f = inputs;
                    inputs_c = inputs;
                    inputs_o = inputs;
                }
                var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
                Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
                var x_i = math_ops.matmul(inputs_i, k_i);
                var x_f = math_ops.matmul(inputs_f, k_f);
                var x_c = math_ops.matmul(inputs_c, k_c);
                var x_o = math_ops.matmul(inputs_o, k_o);
                if (_args.UseBias)
                {
                    var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
                    Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
                    x_i = gen_nn_ops.bias_add(x_i, b_i);
                    x_f = gen_nn_ops.bias_add(x_f, b_f);
                    x_c = gen_nn_ops.bias_add(x_c, b_c);
                    x_o = gen_nn_ops.bias_add(x_o, b_o);
                }
                Tensor h_tm1_i;
                Tensor h_tm1_f;
                Tensor h_tm1_c;
                Tensor h_tm1_o;
                if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
                {
                    h_tm1_i = h_tm1 * rec_dp_mask[0];
                    h_tm1_f = h_tm1 * rec_dp_mask[1];
                    h_tm1_c = h_tm1 * rec_dp_mask[2];
                    h_tm1_o = h_tm1 * rec_dp_mask[3];
                }
                else
                {
                    h_tm1_i = h_tm1;
                    h_tm1_f = h_tm1;
                    h_tm1_c = h_tm1;
                    h_tm1_o = h_tm1;
                }
                var x = new Tensor[] { x_i, x_f, x_c, x_o };
                var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
                (c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
            }
            else
            {
                // Implementation 2: one fused matmul over the full kernel, then split the
                // result into the four gate pre-activations.
                if (0f < _args.Dropout && _args.Dropout < 1f)
                    inputs = inputs * dp_mask[0];
                var z = math_ops.matmul(inputs, _kernel.AsTensor());
                z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
                if (_args.UseBias)
                {
                    z = tf.nn.bias_add(z, _bias);
                }
                var z_array = tf.split(z, num_split: 4, axis: 1);
                (c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
            }
            var h = o * _args.Activation.Apply(c);
            // The Tensors constructor packs every element after the first into an array,
            // so this returns the step output h together with the new (h, c) state.
            return new Tensors(h, h, c);
        }
        /// <summary>
        /// Computes the carry and output gates using the split kernels (implementation 1).
        /// </summary>
        /// <param name="x">Per-gate input projections, ordered (i, f, c, o).</param>
        /// <param name="h_tm1">Per-gate previous hidden states, ordered (i, f, c, o).</param>
        /// <param name="c_tm1">Previous carry state.</param>
        /// <returns>The new carry state and the output gate activation.</returns>
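        /// <remarks>
        /// This mirrors the standard LSTM gate equations, where U_i, U_f, U_c and U_o denote
        /// the four column blocks sliced out of the recurrent kernel below:
        ///     i = recurrent_activation(x_i + matmul(h_tm1_i, U_i))
        ///     f = recurrent_activation(x_f + matmul(h_tm1_f, U_f))
        ///     c = f * c_tm1 + i * activation(x_c + matmul(h_tm1_c, U_c))
        ///     o = recurrent_activation(x_o + matmul(h_tm1_o, U_o))
        /// </remarks>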
        public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
        {
            Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
            Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
                h_tm1_o = h_tm1[3];
            var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
            var startIndex = _recurrent_kernel_tensor.shape[0];
            var endIndex = _recurrent_kernel_tensor.shape[1];
            var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, 0 }, new[] { startIndex, _args.Units });
            var i = _args.RecurrentActivation.Apply(
                x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, _args.Units }, new[] { startIndex, _args.Units * 2 });
            var f = _args.RecurrentActivation.Apply(
                x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units * 3 });
            var c = f * c_tm1 + i * _args.Activation.Apply(
                x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, _args.Units * 3 }, new[] { startIndex, endIndex });
            var o = _args.RecurrentActivation.Apply(
                x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
            return new Tensors(c, o);
        }
        /// <summary>
        /// Computes the carry and output gates using the fused kernel (implementation 2).
        /// </summary>
        /// <param name="z">The fused pre-activations, split into four gate chunks ordered (i, f, c, o).</param>
        /// <param name="c_tm1">Previous carry state.</param>
        /// <returns>The new carry state and the output gate activation.</returns>
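        /// <remarks>
        /// Here each z[k] is one quarter of the fused projection
        /// z = matmul(inputs, kernel) + matmul(h_tm1, recurrent_kernel) (+ bias)
        /// already computed in `Call`, so only the gate activations remain to be applied.
        /// </remarks>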
        public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
        {
            Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
            var i = _args.RecurrentActivation.Apply(z0);
            var f = _args.RecurrentActivation.Apply(z1);
            // The cell candidate uses the cell activation, matching the split-kernel path above.
            var c = f * c_tm1 + i * _args.Activation.Apply(z2);
            var o = _args.RecurrentActivation.Apply(z3);
            return new Tensors(c, o);
        }
        public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
        {
            return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
        }
    }
}
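
For orientation, here is a minimal, hypothetical sketch of driving the cell for a single time step. It assumes the surrounding TensorFlow.NET Keras API: that `LSTMCellArgs` supplies usable defaults for the activations and initializers, and that the inherited `Apply` method builds the layer and forwards inputs and states to the protected `Call` above with roughly the shape shown. The exact argument names and overloads should be checked against the actual library; this is not part of LSTMCell.cs.

    // Hypothetical usage sketch, assuming LSTMCellArgs defaults and the Layer.Apply entry point.
    var cell = new LSTMCell(new LSTMCellArgs
    {
        Units = 64   // other LSTMCellArgs properties are assumed to have sensible defaults
    });

    // One batch of 32 input vectors of width 128.
    var x = tf.zeros(new Shape(32, 128));

    // Zero-filled initial state [h, c], each of shape (32, 64).
    var states = cell.get_initial_state(batch_size: 32, dtype: TF_DataType.TF_FLOAT);

    // Run a single step; per the return statement in Call, the first element is the
    // step output h and the remaining entries carry the new (h, c) state.
    var outputs = cell.Apply(new Tensors(x), states);
    var h = outputs[0];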