You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LSTM.cs 1.3 kB

4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041
  1. using System.Linq;
  2. using Tensorflow.Keras.ArgsDefinition.Rnn;
  3. using Tensorflow.Keras.Engine;
  4. using Tensorflow.Common.Types;
  5. namespace Tensorflow.Keras.Layers.Rnn
  6. {
  7. /// <summary>
  8. /// Long Short-Term Memory layer - Hochreiter 1997.
  9. ///
  10. /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
  11. /// for details about the usage of RNN API.
  12. /// </summary>
  13. public class LSTM : RNN
  14. {
  15. LSTMArgs args;
  16. InputSpec[] state_spec;
  17. int units => args.Units;
  18. public LSTM(LSTMArgs args) :
  19. base(args)
  20. {
  21. this.args = args;
  22. state_spec = new[] { units, units }
  23. .Select(dim => new InputSpec(shape: (-1, dim)))
  24. .ToArray();
  25. }
  26. <<<<<<< HEAD
  27. protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
  28. {
  29. return base.Call(inputs, initial_state: state, training: training);
  30. =======
  31. protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
  32. {
  33. return base.Call(inputs, initial_state: initial_state, training: training);
  34. >>>>>>> master
  35. }
  36. }
  37. }