You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

DropoutRNNCellMixin.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Common.Types;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine;
  namespace Tensorflow.Keras.Layers.Rnn
  {
      /// <summary>
      /// Base class for RNN cells that apply dropout to their inputs and/or their recurrent state.
      /// Masks are built by applying a <see cref="Dropout"/> layer to a tensor of ones shaped
      /// like the given input.
      /// </summary>
      /// <remarks>
      /// Mask caching is not implemented yet: every call generates new masks, and the
      /// cache/reset methods are no-ops.
      /// </remarks>
      public abstract class DropoutRNNCellMixin : RnnCellBase
      {
          /// <summary>
          /// Rate forwarded to <c>DropoutArgs.Rate</c> when building input masks.
          /// A value of exactly 0 makes <see cref="get_dropout_maskcell_for_cell"/> return <c>null</c>.
          /// </summary>
          public float dropout;

          /// <summary>
          /// Rate forwarded to <c>DropoutArgs.Rate</c> when building recurrent-state masks.
          /// A value of exactly 0 makes <see cref="get_recurrent_dropout_maskcell_for_cell"/> return <c>null</c>.
          /// </summary>
          public float recurrent_dropout;

          // TODO(Rinne): deal with cache.
          public DropoutRNNCellMixin(LayerArgs args) : base(args)
          {
          }

          /// <summary>
          /// Placeholder for creating the dropout-mask cache. Currently does nothing.
          /// </summary>
          protected void _create_non_trackable_mask_cache()
          {
              // Intentionally empty until mask caching is implemented (see TODO on the constructor).
          }

          /// <summary>
          /// Placeholder for clearing the cached input dropout mask. Currently does nothing,
          /// because no mask is cached.
          /// </summary>
          public void reset_dropout_mask()
          {
          }

          /// <summary>
          /// Placeholder for clearing the cached recurrent dropout mask. Currently does nothing,
          /// because no mask is cached.
          /// </summary>
          public void reset_recurrent_dropout_mask()
          {
          }

          /// <summary>
          /// Gets the input dropout mask(s) for the RNN cell.
          /// </summary>
          /// <param name="input">Tensor(s) whose shape the mask is built from (via <c>tf.ones_like</c>).</param>
          /// <param name="training">Forwarded to the <see cref="Dropout"/> layer's <c>Apply</c> call.</param>
          /// <param name="count">Number of masks to generate; values greater than 1 produce that many masks.</param>
          /// <returns><c>null</c> when <see cref="dropout"/> is 0; otherwise the generated mask(s).</returns>
          public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
          {
              if (dropout == 0f)
              {
                  return null;
              }

              return _create_dropout_mask(input, training, count);
          }

          /// <summary>
          /// Gets the recurrent dropout mask(s) for the RNN cell.
          /// </summary>
          /// <param name="input">Tensor(s) whose shape the mask is built from (via <c>tf.ones_like</c>).</param>
          /// <param name="training">Forwarded to the <see cref="Dropout"/> layer's <c>Apply</c> call.</param>
          /// <param name="count">Number of masks to generate; values greater than 1 produce that many masks.</param>
          /// <returns><c>null</c> when <see cref="recurrent_dropout"/> is 0; otherwise the generated mask(s).</returns>
          public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
          {
              // Gate on the recurrent rate, not the input rate: the two are independent settings.
              // Checking `dropout` here skipped recurrent dropout whenever input dropout was 0.
              if (recurrent_dropout == 0f)
              {
                  return null;
              }

              return _create_recurrent_dropout_mask(input, training, count);
          }

          /// <summary>
          /// Builds input dropout mask(s) using <see cref="dropout"/> as the rate, without the zero-rate check.
          /// </summary>
          /// <param name="input">Tensor(s) whose shape the mask is built from.</param>
          /// <param name="training">Forwarded to the <see cref="Dropout"/> layer's <c>Apply</c> call.</param>
          /// <param name="count">Number of masks to generate.</param>
          /// <returns>The generated mask(s).</returns>
          public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1)
          {
              return _generate_dropout_mask(tf.ones_like(input), dropout, training, count);
          }

          /// <summary>
          /// Builds recurrent dropout mask(s) using <see cref="recurrent_dropout"/> as the rate, without the zero-rate check.
          /// </summary>
          /// <param name="input">Tensor(s) whose shape the mask is built from.</param>
          /// <param name="training">Forwarded to the <see cref="Dropout"/> layer's <c>Apply</c> call.</param>
          /// <param name="count">Number of masks to generate.</param>
          /// <returns>The generated mask(s).</returns>
          public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1)
          {
              return _generate_dropout_mask(tf.ones_like(input), recurrent_dropout, training, count);
          }

          /// <summary>
          /// Applies a <see cref="Dropout"/> layer with the given rate to <paramref name="ones"/> to produce mask(s).
          /// </summary>
          /// <param name="ones">Tensor of ones that the dropout layer is applied to.</param>
          /// <param name="rate">Value assigned to <c>DropoutArgs.Rate</c>.</param>
          /// <param name="training">Forwarded to <c>Dropout.Apply</c>. Presumably no dropping occurs when false — verify against <see cref="Dropout"/>.</param>
          /// <param name="count">When greater than 1, that many masks are collected into one <see cref="Tensors"/>; otherwise a single mask is returned.</param>
          /// <returns>The mask, or the collection of masks when <paramref name="count"/> &gt; 1.</returns>
          public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1)
          {
              // A fresh Dropout layer is constructed for every mask requested.
              Tensors DroppedInputs()
              {
                  var dropoutLayer = new Dropout(new DropoutArgs { Rate = rate });
                  return dropoutLayer.Apply(ones, training: training);
              }

              if (count <= 1)
              {
                  return DroppedInputs();
              }

              var results = new Tensors();
              for (int i = 0; i < count; i++)
              {
                  results.Add(DroppedInputs());
              }
              return results;
          }
      }
  }