You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Rnn.Test.cs 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. using Tensorflow.Common.Types;
  8. using Tensorflow.Keras.Engine;
  9. using Tensorflow.Keras.Layers.Rnn;
  10. using Tensorflow.Keras.Saving;
  11. using Tensorflow.NumPy;
  12. using Tensorflow.Train;
  13. using static Tensorflow.Binding;
  14. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Layers
  {
      /// <summary>
      /// Tests for the Keras recurrent building blocks exposed through
      /// <c>tf.keras.layers</c> / <c>keras.layers</c>: <c>SimpleRNNCell</c>,
      /// <c>StackedRNNCells</c>, <c>SimpleRNN</c> and the generic <c>RNN</c> wrapper.
      /// Most tests only verify the shapes of the produced tensors.
      /// </summary>
      [TestClass]
      public class Rnn
      {
          /// <summary>
          /// Applies <c>SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f)</c> to a
          /// (4, 100) random input with a zero (4, 64) initial state and checks that the
          /// output and the first returned state both have shape (4, 64).
          /// </summary>
          [TestMethod]
          public void SimpleRNNCell()
          {
              var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
              var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
              var x = tf.random.normal((4, 100));

              var (y, h1) = cell.Apply(inputs: x, states: h0);

              Assert.AreEqual((4, 64), y.shape);
              Assert.AreEqual((4, 64), h1[0].shape);
          }

          /// <summary>
          /// Stacks <c>SimpleRNNCell(4)</c> and <c>SimpleRNNCell(5)</c> with
          /// <c>StackedRNNCells</c>, applies the stack to a (32, 10) input with zero states of
          /// shape (32, 4) and (32, 5), and checks that the output has shape (32, 5) and the
          /// first returned state has shape (32, 4).
          /// </summary>
          [TestMethod]
          public void StackedRNNCell()
          {
              var inputs = tf.ones((32, 10));
              var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
              var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
              var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);

              var (output, state) = stackedRNNCell.Apply(inputs, states);

              Assert.AreEqual((32, 5), output.shape);
              Assert.AreEqual((32, 4), state[0].shape);
          }

          /// <summary>
          /// Smoke test: builds a functional model (<c>Input</c> of shape (10, 8), then
          /// <c>SimpleRNN(4)</c>, then <c>Dense(10)</c>), compiles it with Adam and sparse
          /// categorical cross-entropy, and fits it for 20 epochs on all-ones data.
          /// The test makes no assertions; it passes as long as nothing throws.
          /// </summary>
          [TestMethod]
          public void SimpleRNN()
          {
              // TODO(review): a variant of this test that called
              // SimpleRNN(4, return_sequences: true, return_state: true) on a (6, 10, 8) input and
              // expected shapes (6, 10, 4) / (6, 4) used to live here commented out. The reason it
              // was disabled is not recorded; confirm and restore it as a separate test if it
              // should pass.
              var inputs = keras.Input(shape: (10, 8));
              var x = keras.layers.SimpleRNN(4).Apply(inputs);
              var output = keras.layers.Dense(10).Apply(x);
              var model = keras.Model(inputs, output);
              model.summary();
              model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy());

              var datax = np.ones((16, 10, 8), dtype: dtypes.float32);
              var datay = np.ones((16));
              model.fit(datax, datay, epochs: 20);
          }

          /// <summary>
          /// Wraps <c>SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f)</c> in the
          /// generic <c>RNN</c> layer, applies it to a (32, 10, 8) random input and checks
          /// that the output has shape (32, 10).
          /// </summary>
          [TestMethod]
          public void RNNForSimpleRNNCell()
          {
              var inputs = tf.random.normal((32, 10, 8));
              var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
              var rnn = tf.keras.layers.RNN(cell: cell);

              var output = rnn.Apply(inputs);

              Assert.AreEqual((32, 10), output.shape);
          }

          /// <summary>
          /// Wraps a <c>StackedRNNCells</c> stack of <c>SimpleRNNCell(4)</c> and
          /// <c>SimpleRNNCell(5)</c> in the generic <c>RNN</c> layer, applies it to a
          /// (32, 10, 8) random input and checks that the output has shape (32, 5).
          /// </summary>
          [TestMethod]
          public void RNNForStackedRNNCell()
          {
              var inputs = tf.random.normal((32, 10, 8));
              var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
              var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
              var rnn = tf.keras.layers.RNN(cell: stackedRNNCell);

              var output = rnn.Apply(inputs);

              Assert.AreEqual((32, 5), output.shape);
          }

          /// <summary>
          /// Smoke test: concatenates <c>new Shape(Unknown)</c> with the dimensions
          /// { 1, 2, 3 } and writes the resulting shape to the console.
          /// The test makes no assertions; it passes as long as nothing throws.
          /// </summary>
          [TestMethod]
          public void WlzTest()
          {
              long[] b = { 1, 2, 3 };
              Shape a = new Shape(Unknown).concatenate(b);
              Console.WriteLine(a);
          }
      }
  }