You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DropoutRNNCellMixin.cs 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Common.Types;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine;
  namespace Tensorflow.Keras.Layers.Rnn
  {
      /// <summary>
      /// Base class for RNN cells that support input dropout and recurrent-state dropout.
      /// It provides helpers that build dropout masks by applying a <see cref="Dropout"/> layer
      /// to a tensor of ones shaped like the given input.
      /// </summary>
      public abstract class DropoutRNNCellMixin : RnnCellBase
      {
          /// <summary>
          /// Dropout rate used for the input mask. When it is exactly 0,
          /// <see cref="get_dropout_maskcell_for_cell"/> returns <c>null</c>.
          /// </summary>
          public float dropout;

          /// <summary>
          /// Dropout rate used for the recurrent mask. When it is exactly 0,
          /// <see cref="get_recurrent_dropout_maskcell_for_cell"/> returns <c>null</c>.
          /// </summary>
          public float recurrent_dropout;

          // TODO(Rinne): deal with cache.
          /// <summary>Forwards <paramref name="args"/> to <see cref="RnnCellBase"/>.</summary>
          public DropoutRNNCellMixin(LayerArgs args) : base(args)
          {
          }

          /// <summary>
          /// Gets the input dropout mask for the RNN cell.
          /// </summary>
          /// <param name="input">Tensor whose <c>ones_like</c> is masked.</param>
          /// <param name="training">Training flag forwarded to the dropout layer.</param>
          /// <param name="count">Number of masks to generate; values greater than 1 yield that many masks.</param>
          /// <returns><c>null</c> when <see cref="dropout"/> is 0; otherwise the generated mask(s).</returns>
          public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
          {
              if (dropout == 0f)
              {
                  return null;
              }
              return _create_dropout_mask(input, training, count);
          }

          /// <summary>
          /// Gets the recurrent dropout mask for the RNN cell.
          /// </summary>
          /// <param name="input">Tensor whose <c>ones_like</c> is masked.</param>
          /// <param name="training">Training flag forwarded to the dropout layer.</param>
          /// <param name="count">Number of masks to generate; values greater than 1 yield that many masks.</param>
          /// <returns><c>null</c> when <see cref="recurrent_dropout"/> is 0; otherwise the generated mask(s).</returns>
          public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
          {
              // Guard on the recurrent rate: this mask is built from recurrent_dropout,
              // so the input dropout rate is irrelevant here.
              if (recurrent_dropout == 0f)
              {
                  return null;
              }
              return _create_recurrent_dropout_mask(input, training, count);
          }

          /// <summary>
          /// Builds the input dropout mask(s) from <see cref="dropout"/> without checking for a zero rate.
          /// </summary>
          /// <param name="input">Tensor whose <c>ones_like</c> is masked.</param>
          /// <param name="training">Training flag forwarded to the dropout layer.</param>
          /// <param name="count">Number of masks to generate.</param>
          public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1)
          {
              return _generate_dropout_mask(tf.ones_like(input), dropout, training, count);
          }

          /// <summary>
          /// Builds the recurrent dropout mask(s) from <see cref="recurrent_dropout"/> without checking for a zero rate.
          /// </summary>
          /// <param name="input">Tensor whose <c>ones_like</c> is masked.</param>
          /// <param name="training">Training flag forwarded to the dropout layer.</param>
          /// <param name="count">Number of masks to generate.</param>
          public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1)
          {
              return _generate_dropout_mask(tf.ones_like(input), recurrent_dropout, training, count);
          }

          /// <summary>
          /// Generates dropout mask(s) by applying a freshly constructed <see cref="Dropout"/> layer
          /// with rate <paramref name="rate"/> to <paramref name="ones"/>.
          /// </summary>
          /// <param name="ones">Tensor the dropout layer is applied to (callers pass a ones_like tensor).</param>
          /// <param name="rate">Dropout rate assigned to <see cref="DropoutArgs.Rate"/>.</param>
          /// <param name="training">Training flag passed to <c>Dropout.Apply</c>.
          /// NOTE(review): presumably the layer returns <paramref name="ones"/> unchanged when this is
          /// false — confirm against the Dropout implementation.</param>
          /// <param name="count">When greater than 1, that many masks are generated and returned together;
          /// otherwise a single mask is returned.</param>
          public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1)
          {
              // Each mask comes from its own Dropout layer instance.
              Tensors dropped_inputs()
              {
                  DropoutArgs args = new DropoutArgs();
                  args.Rate = rate;
                  var dropoutLayer = new Dropout(args);
                  return dropoutLayer.Apply(ones, training: training);
              }

              if (count > 1)
              {
                  Tensors results = new Tensors();
                  for (int i = 0; i < count; i++)
                  {
                      results.Add(dropped_inputs());
                  }
                  return results;
              }
              return dropped_inputs();
          }
      }
  }