You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SimpleRNNCell.cs 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Keras.ArgsDefinition.Rnn;
  5. using Tensorflow.Keras.Engine;
  6. using Tensorflow.Keras.Saving;
  7. using Tensorflow.Common.Types;
  8. using Tensorflow.Common.Extensions;
  9. using Tensorflow.Keras.Utils;
  10. namespace Tensorflow.Keras.Layers.Rnn
  11. {
  /// <summary>
  /// Cell class for SimpleRNN.
  /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
  /// for details about the usage of RNN API.
  /// This class processes one step within the whole time sequence input, whereas
  /// `tf.keras.layers.SimpleRNN` processes the whole sequence.
  /// </summary>
  public class SimpleRNNCell : DropoutRNNCellMixin
  {
      private SimpleRNNCellArgs _args;

      // Weights are created lazily in build(); _bias stays null when UseBias is false.
      private IVariableV1 _kernel;
      private IVariableV1 _recurrent_kernel;
      private IVariableV1 _bias;

      // Both sizes are derived from args.Units in the constructor and never change afterwards.
      private readonly GeneralizedTensorShape _state_size;
      private readonly GeneralizedTensorShape _output_size;

      /// <summary>Size of the recurrent state (constructed from <c>args.Units</c>).</summary>
      public override GeneralizedTensorShape StateSize => _state_size;

      /// <summary>Size of the step output (constructed from <c>args.Units</c>).</summary>
      public override GeneralizedTensorShape OutputSize => _output_size;

      public override bool IsTFRnnCell => true;

      public override bool SupportOptionalArgs => false;

      /// <summary>
      /// Creates the cell from its argument object.
      /// <c>Dropout</c> and <c>RecurrentDropout</c> on <paramref name="args"/> are clamped
      /// in place to the range [0, 1].
      /// </summary>
      /// <param name="args">Cell configuration (units, initializers, dropout rates, activation, bias flag).</param>
      /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not a positive integer.</exception>
      public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
      {
          this._args = args;
          if (args.Units <= 0)
          {
              throw new ValueError(
                  $"units must be a positive integer, got {args.Units}");
          }

          // Clamp both dropout rates into [0, 1]; out-of-range values are silently corrected.
          this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
          this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));

          _state_size = new GeneralizedTensorShape(args.Units);
          _output_size = new GeneralizedTensorShape(args.Units);
      }

      /// <summary>
      /// Creates the cell weights: "kernel" with shape (input_dim, units),
      /// "recurrent_kernel" with shape (units, units) and, when <c>UseBias</c> is set,
      /// "bias" with shape (units). Marks the layer as built.
      /// </summary>
      /// <param name="input_shape">Input shape wrapper; its last dimension is used as the input feature size.</param>
      public override void build(KerasShapesWrapper input_shape)
      {
          // TODO(Rinne): add the cache.
          var single_shape = input_shape.ToSingleShape();
          var input_dim = single_shape[-1];

          _kernel = add_weight("kernel", (input_dim, _args.Units),
              initializer: _args.KernelInitializer
          );

          _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units),
              initializer: _args.RecurrentInitializer
          );

          if (_args.UseBias)
          {
              _bias = add_weight("bias", (_args.Units),
                  initializer: _args.BiasInitializer
              );
          }

          built = true;
      }

      // TODO(Rinne): revise the training param (with refactoring of the framework)
      /// <summary>
      /// Runs a single step:
      /// <c>output = activation(inputs · kernel [+ bias] + prev_output · recurrent_kernel)</c>,
      /// with the dropout masks applied to <paramref name="inputs"/> and to the previous
      /// output when the mixin provides them.
      /// </summary>
      /// <param name="inputs">Input for the current step.</param>
      /// <param name="states">Previous state; when nested, its first element is used as the previous output.</param>
      /// <param name="training">
      /// Forwarded to the dropout-mask helpers. A null value is treated as <c>false</c>
      /// (inference) instead of throwing on <c>Nullable.Value</c>.
      /// </param>
      /// <param name="optional_args">Not used by this cell.</param>
      /// <returns>The step output together with the new state (both are the same tensor).</returns>
      /// <exception cref="ValueError">Thrown when <paramref name="states"/> is null.</exception>
      protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
      {
          // The previous output is always needed for the recurrent matmul below, so fail early
          // with a clear message rather than with a null failure deep inside the op.
          if (states is null)
          {
              throw new ValueError(
                  "states must be provided when calling SimpleRNNCell.");
          }

          // `training` defaults to null; fall back to inference mode rather than crashing.
          bool is_training = training ?? false;

          // TODO(Rinne): check if it will have multiple tensors when not nested.
          Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
          var dp_mask = get_dropout_mask_for_cell(inputs, is_training);
          var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, is_training);

          Tensor h;
          if (dp_mask != null)
          {
              h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
          }
          else
          {
              h = math_ops.matmul(inputs, _kernel.AsTensor());
          }

          // _bias is only created in build() when UseBias is enabled.
          if (_bias != null)
          {
              h = tf.nn.bias_add(h, _bias);
          }

          if (rec_dp_mask != null)
          {
              prev_output = math_ops.multiply(prev_output, rec_dp_mask);
          }

          Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
          if (_args.Activation != null)
          {
              output = _args.Activation.Apply(output);
          }

          if (Nest.IsNested(states))
          {
              // Keep the nested structure of the incoming state in the returned value.
              return new Nest<Tensor>(new List<Nest<Tensor>> {
                  new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(output) }), new Nest<Tensor>(output) })
                  .ToTensors();
          }
          else
          {
              return new Tensors(output, output);
          }
      }

      /// <summary>
      /// Builds the initial state for this cell by delegating to
      /// <c>RnnUtils.generate_zero_filled_state_for_cell</c>.
      /// </summary>
      /// <param name="inputs">Optional inputs; used to supply the batch size and dtype when those are omitted.</param>
      /// <param name="batch_size">Batch size; required when <paramref name="inputs"/> is null.</param>
      /// <param name="dtype">State dtype; required when <paramref name="inputs"/> is null.</param>
      /// <returns>The state produced by the helper.</returns>
      /// <exception cref="ValueError">
      /// Thrown when <paramref name="inputs"/> is null and either <paramref name="batch_size"/>
      /// or <paramref name="dtype"/> is missing.
      /// </exception>
      public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
      {
          if (inputs is null && (batch_size is null || dtype is null))
          {
              throw new ValueError(
                  "batch_size and dtype cannot be None while constructing initial state " +
                  $"when inputs is not provided. Received: batch_size={batch_size}, dtype={dtype}");
          }

          // All parameters are optional, so never dereference the nullables blindly:
          // explicit values win, otherwise fall back to the first dimension / dtype of `inputs`.
          // NOTE(review): assumes dimension 0 of `inputs` is the batch dimension — confirm against callers.
          long resolved_batch_size = batch_size ?? inputs.shape[0];
          TF_DataType resolved_dtype = dtype ?? inputs.dtype;

          return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, resolved_batch_size, resolved_dtype);
      }
  }
  109. }