You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RNN.cs 3.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. using System;
  2. using System.Collections.Generic;
  3. using Tensorflow.Keras.ArgsDefinition;
  4. using Tensorflow.Keras.Engine;
  namespace Tensorflow.Keras.Layers
  {
      /// <summary>
      /// Recurrent layer built around a single RNN cell, or around a list of cells
      /// that is wrapped in a <see cref="StackedRNNCells"/>.
      /// </summary>
      /// <remarks>
      /// Only construction and argument normalisation are implemented in this class.
      /// Initial-state generation is a stub that always throws
      /// <see cref="NotImplementedException"/>.
      /// </remarks>
      public class RNN : Layer
      {
          // Arguments this layer was constructed with; assigned once in the constructor.
          private readonly RNNArgs args;

          /// <summary>
          /// Creates the layer from <paramref name="args"/>. The arguments are first
          /// normalised by <see cref="PreConstruct"/> and then handed to the base
          /// <see cref="Layer"/> constructor.
          /// </summary>
          /// <param name="args">Layer arguments; must not be null.</param>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="args"/> is null.
          /// </exception>
          public RNN(RNNArgs args) : base(PreConstruct(args))
          {
              this.args = args;
              SupportsMasking = true;

              // The input shape is unknown yet, it could have nested tensor inputs, and
              // the input spec will be the list of specs for nested inputs, the structure
              // of the input_spec will be the same as the input.
              //
              // Not yet ported from the Python implementation:
              //   self.input_spec = None
              //   self.state_spec = None
              //   self._states = None
              //   self.constants_spec = None
              //   self._num_constants = 0
              //   if stateful:
              //     if ds_context.has_strategy():
              //       raise ValueError('RNNs with stateful=True not yet supported with '
              //                        'tf.distribute.Strategy.')
          }

          /// <summary>
          /// Normalises <paramref name="args"/> before it reaches the base constructor:
          /// guarantees a non-null <c>Kwargs</c> dictionary and, when no
          /// <c>"input_shape"</c> entry exists but <c>"input_dim"</c> and/or
          /// <c>"input_length"</c> do, stores an <c>"input_shape"</c> entry built from them.
          /// </summary>
          /// <param name="args">Arguments to normalise; mutated in place.</param>
          /// <returns>The same <paramref name="args"/> instance.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="args"/> is null.
          /// </exception>
          private static RNNArgs PreConstruct(RNNArgs args)
          {
              // Fail with a named argument instead of a NullReferenceException raised
              // from inside the base-constructor call.
              if (args == null)
              {
                  throw new ArgumentNullException(nameof(args));
              }

              if (args.Kwargs == null)
              {
                  args.Kwargs = new Dictionary<string, object>();
              }

              // If true, the output for masked timestep will be zeros, whereas in the
              // false case, output from previous timestep is returned for masked timestep.
              // NOTE: the value is read and cast here but not consumed anywhere yet.
              var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);

              var propIS = args.Kwargs.Get("input_shape", null);
              var propID = args.Kwargs.Get("input_dim", null);
              var propIL = args.Kwargs.Get("input_length", null);

              if (propIS == null && (propID != null || propIL != null))
              {
                  // (input_length, input_dim); a missing component is replaced by a
                  // NoneValue placeholder (maybe null is needed here instead).
                  object input_shape = (
                      propIL ?? new NoneValue(),
                      propID ?? new NoneValue());
                  args.Kwargs["input_shape"] = input_shape;
              }

              return args;
          }

          /// <summary>
          /// Builds a new <see cref="RNN"/> around a single cell. This is an instance
          /// method but does not read any state from the current instance.
          /// </summary>
          /// <param name="cell">Cell assigned to <c>RNNArgs.Cell</c>.</param>
          /// <param name="return_sequences">Value for <c>RNNArgs.ReturnSequences</c>.</param>
          /// <param name="return_state">Value for <c>RNNArgs.ReturnState</c>.</param>
          /// <param name="go_backwards">Value for <c>RNNArgs.GoBackwards</c>.</param>
          /// <param name="stateful">Value for <c>RNNArgs.Stateful</c>.</param>
          /// <param name="unroll">Value for <c>RNNArgs.Unroll</c>.</param>
          /// <param name="time_major">Value for <c>RNNArgs.TimeMajor</c>.</param>
          /// <returns>A newly constructed <see cref="RNN"/>.</returns>
          public RNN New(LayerRnnCell cell,
              bool return_sequences = false,
              bool return_state = false,
              bool go_backwards = false,
              bool stateful = false,
              bool unroll = false,
              bool time_major = false)
              => new RNN(new RNNArgs
              {
                  Cell = cell,
                  ReturnSequences = return_sequences,
                  ReturnState = return_state,
                  GoBackwards = go_backwards,
                  Stateful = stateful,
                  Unroll = unroll,
                  TimeMajor = time_major
              });

          /// <summary>
          /// Builds a new <see cref="RNN"/> around several cells, which are wrapped in a
          /// <see cref="StackedRNNCells"/>. This is an instance method but does not read
          /// any state from the current instance.
          /// </summary>
          /// <param name="cell">Cells passed to <c>StackedRNNCellsArgs.Cells</c>.</param>
          /// <param name="return_sequences">Value for <c>RNNArgs.ReturnSequences</c>.</param>
          /// <param name="return_state">Value for <c>RNNArgs.ReturnState</c>.</param>
          /// <param name="go_backwards">Value for <c>RNNArgs.GoBackwards</c>.</param>
          /// <param name="stateful">Value for <c>RNNArgs.Stateful</c>.</param>
          /// <param name="unroll">Value for <c>RNNArgs.Unroll</c>.</param>
          /// <param name="time_major">Value for <c>RNNArgs.TimeMajor</c>.</param>
          /// <returns>A newly constructed <see cref="RNN"/>.</returns>
          public RNN New(IList<RnnCell> cell,
              bool return_sequences = false,
              bool return_state = false,
              bool go_backwards = false,
              bool stateful = false,
              bool unroll = false,
              bool time_major = false)
              => new RNN(new RNNArgs
              {
                  Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
                  ReturnSequences = return_sequences,
                  ReturnState = return_state,
                  GoBackwards = go_backwards,
                  Stateful = stateful,
                  Unroll = unroll,
                  TimeMajor = time_major
              });

          /// <summary>
          /// Placeholder for initial-state creation. <paramref name="inputs"/> is
          /// currently ignored and the call is forwarded with null arguments to
          /// <see cref="_generate_zero_filled_state_for_cell"/>, which always throws.
          /// </summary>
          /// <param name="inputs">Currently unused.</param>
          /// <returns>Never returns normally at present.</returns>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          protected Tensor get_initial_state(Tensor inputs)
          {
              return _generate_zero_filled_state_for_cell(null, null);
          }

          /// <summary>
          /// Stub for building a zero-filled state for <paramref name="cell"/>;
          /// not implemented.
          /// </summary>
          /// <param name="cell">Currently unused.</param>
          /// <param name="batch_size">Currently unused.</param>
          /// <returns>Never returns normally.</returns>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          private Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size)
          {
              // A non-empty message so the failure is identifiable in logs and traces.
              throw new NotImplementedException(
                  $"{nameof(RNN)}.{nameof(_generate_zero_filled_state_for_cell)} is not implemented yet.");
          }
      }
  }