You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DropoutRNNCellMixin.cs 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. using System;
  2. using System.Collections.Generic;
  3. <<<<<<< HEAD
  4. using Tensorflow.Keras.ArgsDefinition;
  5. using Tensorflow.Keras.ArgsDefinition.Rnn;
  6. using Tensorflow.Keras.Engine;
  7. namespace Tensorflow.Keras.Layers.Rnn
  8. {
  9. public class DropoutRNNCellMixin
  10. {
  11. public float dropout;
  12. public float recurrent_dropout;
  13. // Get the dropout mask for RNN cell's input.
  14. public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
  15. =======
  16. using System.Text;
  17. using Tensorflow.Common.Types;
  18. using Tensorflow.Keras.ArgsDefinition;
  19. using Tensorflow.Keras.Engine;
  20. namespace Tensorflow.Keras.Layers.Rnn
  21. {
  22. public abstract class DropoutRNNCellMixin: RnnCellBase
  23. {
  24. public float dropout;
  25. public float recurrent_dropout;
  26. // TODO(Rinne): deal with cache.
  27. public DropoutRNNCellMixin(LayerArgs args): base(args)
  28. {
  29. }
  30. protected void _create_non_trackable_mask_cache()
  31. {
  32. }
  33. public void reset_dropout_mask()
  34. {
  35. }
  36. public void reset_recurrent_dropout_mask()
  37. {
  38. }
  39. public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
  40. >>>>>>> 90a65d7d98b92f26574ac32392ed802a57d4d2c8
  41. {
  42. if (dropout == 0f)
  43. return null;
  44. return _generate_dropout_mask(
  45. tf.ones_like(input),
  46. dropout,
  47. training,
  48. count);
  49. }
  50. // Get the recurrent dropout mask for RNN cell.
  51. <<<<<<< HEAD
  52. public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
  53. =======
  54. public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
  55. >>>>>>> 90a65d7d98b92f26574ac32392ed802a57d4d2c8
  56. {
  57. if (dropout == 0f)
  58. return null;
  59. return _generate_dropout_mask(
  60. tf.ones_like(input),
  61. recurrent_dropout,
  62. training,
  63. count);
  64. }
  65. public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1)
  66. {
  67. return _generate_dropout_mask(
  68. tf.ones_like(input),
  69. dropout,
  70. training,
  71. count);
  72. }
  73. public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1)
  74. {
  75. return _generate_dropout_mask(
  76. tf.ones_like(input),
  77. recurrent_dropout,
  78. training,
  79. count);
  80. }
  81. public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1)
  82. {
  83. Tensors dropped_inputs()
  84. {
  85. DropoutArgs args = new DropoutArgs();
  86. args.Rate = rate;
  87. var DropoutLayer = new Dropout(args);
  88. var mask = DropoutLayer.Apply(ones, training: training);
  89. return mask;
  90. }
  91. if (count > 1)
  92. {
  93. Tensors results = new Tensors();
  94. for (int i = 0; i < count; i++)
  95. {
  96. results.Add(dropped_inputs());
  97. }
  98. return results;
  99. }
  100. return dropped_inputs();
  101. }
  102. }
  103. <<<<<<< HEAD
  104. =======
  105. >>>>>>> 90a65d7d98b92f26574ac32392ed802a57d4d2c8
  106. }