
LSTMCell.cs 10 kB

using Newtonsoft.Json;
using Serilog.Core;
using System;
using System.Diagnostics;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Cell class for the LSTM layer.
    /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    /// for details about the usage of the RNN API.
    /// This class processes one step within the whole time sequence input, whereas
    /// `tf.keras.layers.LSTM` processes the whole sequence.
    /// </summary>
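    /// <remarks>
    /// One step follows the standard LSTM formulation: for input x_t, previous hidden
    /// state h_{t-1} and previous carry state c_{t-1},
    ///   i = recurrent_activation(x_t W_i + h_{t-1} U_i + b_i)
    ///   f = recurrent_activation(x_t W_f + h_{t-1} U_f + b_f)
    ///   c = f * c_{t-1} + i * activation(x_t W_c + h_{t-1} U_c + b_c)
    ///   o = recurrent_activation(x_t W_o + h_{t-1} U_o + b_o)
    ///   h = o * activation(c)
    /// where W_*, U_* and b_* are the per-gate blocks of `_kernel`, `_recurrent_kernel`
    /// and `_bias` created in `build`.
    /// </remarks>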
    public class LSTMCell : DropoutRNNCellMixin
    {
        LSTMCellArgs _args;
        IVariableV1 _kernel;
        IVariableV1 _recurrent_kernel;
        IInitializer _bias_initializer;
        IVariableV1 _bias;
        INestStructure<long> _state_size;
        INestStructure<long> _output_size;
        public override INestStructure<long> StateSize => _state_size;
        public override INestStructure<long> OutputSize => _output_size;
        public override bool SupportOptionalArgs => false;
        public LSTMCell(LSTMCellArgs args)
            : base(args)
        {
            _args = args;
            if (args.Units <= 0)
            {
                throw new ValueError(
                    $"units must be a positive integer, got {args.Units}");
            }
            _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
            _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
            if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
            {
                Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
                    "Using `implementation=1`.");
                _args.Implementation = 1;
            }
            _state_size = new NestList<long>(_args.Units, _args.Units);
            _output_size = new NestNode<long>(_args.Units);
        }
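        /// <summary>
        /// Creates the cell's weights: `kernel` of shape (input_dim, units * 4),
        /// `recurrent_kernel` of shape (units, units * 4) and, when `UseBias` is set,
        /// `bias` of shape (units * 4,). The four column blocks correspond to the
        /// input, forget, cell and output gates, in that order.
        /// </summary>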
        public override void build(KerasShapesWrapper input_shape)
        {
            base.build(input_shape);
            var single_shape = input_shape.ToSingleShape();
            var input_dim = single_shape[-1];
            _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
                initializer: _args.KernelInitializer);
            _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
                initializer: _args.RecurrentInitializer);
            if (_args.UseBias)
            {
                if (_args.UnitForgetBias)
                {
                    // Keras-style `unit_forget_bias` initializer: input-gate bias, ones for the
                    // forget gate, then the cell- and output-gate biases, (units * 4,) in total.
                    Tensor bias_initializer()
                    {
                        return keras.backend.concatenate(
                            new Tensors(
                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
                                tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units * 2)))), axis: 0);
                    }
                }
                else
                {
                    _bias_initializer = _args.BiasInitializer;
                }
                _bias = add_weight("bias", (_args.Units * 4),
                    initializer: _bias_initializer);
            }
            built = true;
        }
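        /// <summary>
        /// Runs a single LSTM step. `states` carries the previous hidden state h and
        /// the previous carry state c; the result is the new hidden state followed by
        /// the packed state pair [h, c] for the wrapping RNN layer.
        /// </summary>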
        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            var h_tm1 = states[0]; // previous memory state
            var c_tm1 = states[1]; // previous carry state
            var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
                h_tm1, training.Value, count: 4);

            Tensor c;
            Tensor o;
            if (_args.Implementation == 1)
            {
                Tensor inputs_i;
                Tensor inputs_f;
                Tensor inputs_c;
                Tensor inputs_o;
                if (0f < _args.Dropout && _args.Dropout < 1f)
                {
                    inputs_i = inputs * dp_mask[0];
                    inputs_f = inputs * dp_mask[1];
                    inputs_c = inputs * dp_mask[2];
                    inputs_o = inputs * dp_mask[3];
                }
                else
                {
                    inputs_i = inputs;
                    inputs_f = inputs;
                    inputs_c = inputs;
                    inputs_o = inputs;
                }
                var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
                Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
                var x_i = math_ops.matmul(inputs_i, k_i);
                var x_f = math_ops.matmul(inputs_f, k_f);
                var x_c = math_ops.matmul(inputs_c, k_c);
                var x_o = math_ops.matmul(inputs_o, k_o);
                if (_args.UseBias)
                {
                    var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
                    Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
                    x_i = gen_nn_ops.bias_add(x_i, b_i);
                    x_f = gen_nn_ops.bias_add(x_f, b_f);
                    x_c = gen_nn_ops.bias_add(x_c, b_c);
                    x_o = gen_nn_ops.bias_add(x_o, b_o);
                }
                Tensor h_tm1_i;
                Tensor h_tm1_f;
                Tensor h_tm1_c;
                Tensor h_tm1_o;
                if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
                {
                    h_tm1_i = h_tm1 * rec_dp_mask[0];
                    h_tm1_f = h_tm1 * rec_dp_mask[1];
                    h_tm1_c = h_tm1 * rec_dp_mask[2];
                    h_tm1_o = h_tm1 * rec_dp_mask[3];
                }
                else
                {
                    h_tm1_i = h_tm1;
                    h_tm1_f = h_tm1;
                    h_tm1_c = h_tm1;
                    h_tm1_o = h_tm1;
                }
                var x = new Tensor[] { x_i, x_f, x_c, x_o };
                var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
                (c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
            }
            else
            {
                if (0f < _args.Dropout && _args.Dropout < 1f)
                    inputs = inputs * dp_mask[0];
                var z = math_ops.matmul(inputs, _kernel.AsTensor());
                z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
                if (_args.UseBias)
                {
                    z = tf.nn.bias_add(z, _bias);
                }
                var z_array = tf.split(z, num_split: 4, axis: 1);
                (c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
            }
            var h = o * _args.Activation.Apply(c);
            // A Nest is built here because the Tensors constructor packs every element
            // after the first one into a single array.
            return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestList<Tensor>(h, c) }).ToTensors();
        }
        /// <summary>
        /// Computes carry and output using split kernels.
        /// </summary>
        /// <param name="x">Per-gate input projections x_i, x_f, x_c, x_o.</param>
        /// <param name="h_tm1">Previous hidden state, one (possibly dropout-masked) copy per gate.</param>
        /// <param name="c_tm1">Previous carry state.</param>
        /// <returns>The new carry state c and the output gate activation o.</returns>
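        /// <remarks>
        /// In gate terms, using the units-wide column blocks U_i, U_f, U_c, U_o of
        /// `_recurrent_kernel`:
        ///   i = recurrent_activation(x_i + h_tm1_i U_i)
        ///   f = recurrent_activation(x_f + h_tm1_f U_f)
        ///   c = f * c_tm1 + i * activation(x_c + h_tm1_c U_c)
        ///   o = recurrent_activation(x_o + h_tm1_o U_o)
        /// </remarks>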
        public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
        {
            Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
            Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
                h_tm1_o = h_tm1[3];
            var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
            int startIndex = (int)_recurrent_kernel_tensor.shape[0];
            var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, 0 }, new[] { startIndex, _args.Units });
            var i = _args.RecurrentActivation.Apply(
                x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, _args.Units }, new[] { startIndex, _args.Units });
            var f = _args.RecurrentActivation.Apply(
                x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units });
            var c = f * c_tm1 + i * _args.Activation.Apply(
                x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
            // The output gate uses the recurrent activation, matching the fused path below.
            var o = _args.RecurrentActivation.Apply(
                x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
            return new Tensors(c, o);
        }
        /// <summary>
        /// Computes carry and output using fused kernels.
        /// </summary>
        /// <param name="z">The fused projection, split into four per-gate blocks.</param>
        /// <param name="c_tm1">Previous carry state.</param>
        /// <returns>The new carry state c and the output gate activation o.</returns>
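        /// <remarks>
        /// In gate terms: i = recurrent_activation(z[0]), f = recurrent_activation(z[1]),
        /// c = f * c_tm1 + i * activation(z[2]), o = recurrent_activation(z[3]).
        /// </remarks>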
        public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
        {
            Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
            var i = _args.RecurrentActivation.Apply(z0);
            var f = _args.RecurrentActivation.Apply(z1);
            var c = f * c_tm1 + i * _args.Activation.Apply(z2);
            var o = _args.RecurrentActivation.Apply(z3);
            return new Tensors(c, o);
        }
    }
}
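
Below is a minimal construction sketch, not part of the file above. It assumes `LSTMCellArgs` exposes the settings this class reads (`Units`, `Dropout`, `RecurrentDropout`, `Implementation`) as settable properties usable in an object initializer; only the property names themselves come from the code.

using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;

// Hypothetical usage: configure a 32-unit cell with the same knobs the class reads.
var cell = new LSTMCell(new LSTMCellArgs
{
    Units = 32,             // must be a positive integer (checked in the constructor)
    Dropout = 0.1f,         // clamped into [0, 1]
    RecurrentDropout = 0f,  // a non-zero value forces Implementation = 1
    Implementation = 2      // fused path via _compute_carry_and_output_fused
});

// The cell advances a single time step; processing a whole sequence is the job of a
// wrapping RNN layer (`tf.keras.layers.LSTM` combines the two).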