You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Rnn.Test.cs 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. using Tensorflow.Common.Types;
  8. using Tensorflow.Keras.ArgsDefinition;
  9. using Tensorflow.Keras.Engine;
  10. using Tensorflow.Keras.Layers;
  11. using Tensorflow.Keras.Saving;
  12. using Tensorflow.NumPy;
  13. using Tensorflow.Train;
  14. using static Tensorflow.Binding;
  15. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Layers
  {
      /// <summary>
      /// Smoke tests for the recurrent Keras cells and layers (SimpleRNN, LSTM, GRU,
      /// stacked cells, the generic RNN wrapper and Bidirectional). Most tests only
      /// verify output/state shapes; the two MNIST tests only verify that training
      /// runs to completion.
      /// </summary>
      [TestClass]
      public class Rnn
      {
          /// <summary>
          /// Applies a 64-unit SimpleRNNCell (with dropout and recurrent dropout) to a
          /// (4, 100) input and a zero (4, 64) state, then checks that the output and
          /// the first returned state both have shape (4, 64).
          /// </summary>
          [TestMethod]
          public void SimpleRNNCell()
          {
              var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
              var initialState = new Tensors { tf.zeros(new Shape(4, 64)) };
              var x = tf.random.normal((4, 100));

              var (y, nextState) = cell.Apply(inputs: x, states: initialState);

              Assert.AreEqual((4, 64), y.shape);
              Assert.AreEqual((4, 64), nextState[0].shape);
          }

          /// <summary>
          /// Stacks a 4-unit and a 5-unit SimpleRNNCell, applies the stack to a (32, 10)
          /// input with matching zero states, and checks that the output has shape
          /// (32, 5) and the first returned state has shape (32, 4).
          /// </summary>
          [TestMethod]
          public void StackedRNNCell()
          {
              var inputs = tf.ones((32, 10));
              var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
              var cells = new IRnnCell[]
              {
                  tf.keras.layers.SimpleRNNCell(4),
                  tf.keras.layers.SimpleRNNCell(5),
              };
              var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);

              var (output, state) = stackedRNNCell.Apply(inputs, states);

              Assert.AreEqual((32, 5), output.shape);
              Assert.AreEqual((32, 4), state[0].shape);
          }

          /// <summary>
          /// Applies a 4-unit LSTMCell to a (2, 100) input with two zero (2, 4) states
          /// and checks that the output and the first returned state have shape (2, 4).
          /// </summary>
          [TestMethod]
          public void LSTMCell()
          {
              var inputs = tf.ones((2, 100));
              var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
              var lstmCell = tf.keras.layers.LSTMCell(4);

              var (output, newStates) = lstmCell.Apply(inputs, states);

              Assert.AreEqual((2, 4), output.shape);
              Assert.AreEqual((2, 4), newStates[0].shape);
          }

          /// <summary>
          /// Builds a 784-input model (Reshape to 28x28, LSTM(50) returning sequences,
          /// LSTM(100), Dense(10, softmax)), compiles it with Adam and categorical
          /// cross-entropy, and fits it for one epoch on one-hot MNIST data with an
          /// all-ones sample-weight vector. There are no assertions: the test passes
          /// when training completes without throwing.
          /// </summary>
          [TestMethod]
          public void TrainLSTMWithMnist()
          {
              var input = keras.Input((784));
              var x = keras.layers.Reshape((28, 28)).Apply(input);
              x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
              x = keras.layers.LSTM(100).Apply(x);
              var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

              var model = keras.Model(input, output);
              model.summary();
              model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });

              // NOTE(review): the large ValidationSize presumably keeps the training
              // split small so the test stays fast - confirm against the loader.
              var loadSetting = new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = true,
                  ValidationSize = 55000,
              };

              // GetAwaiter().GetResult() rethrows the loader's own exception, whereas
              // .Result would wrap it in an AggregateException and obscure the failure.
              var dataset = new MnistModelLoader().LoadAsync(loadSetting).GetAwaiter().GetResult();

              // One weight per training sample, all equal to 1.
              var sampleWeight = np.ones(((int)dataset.Train.Data.shape[0]));
              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight: sampleWeight);
          }

          /// <summary>
          /// Builds a 784-input model (Reshape to 28x28, SimpleRNN(10), Dense(10,
          /// softmax)), compiles it with Adam and categorical cross-entropy, and fits
          /// it for two epochs on MNIST. There are no assertions: the test passes when
          /// training completes without throwing.
          /// </summary>
          [TestMethod]
          public void SimpleRNN()
          {
              var input = keras.Input((784));
              var x = keras.layers.Reshape((28, 28)).Apply(input);
              x = keras.layers.SimpleRNN(10).Apply(x);
              var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

              var model = keras.Model(input, output);
              model.summary();
              model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });

              // Categorical cross-entropy over a 10-way softmax expects one-hot targets,
              // so the labels are loaded one-hot (as TrainLSTMWithMnist does) rather than
              // as class indices.
              var loadSetting = new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = true,
                  ValidationSize = 58000,
              };

              // GetAwaiter().GetResult() rethrows the loader's own exception, whereas
              // .Result would wrap it in an AggregateException and obscure the failure.
              var dataset = new MnistModelLoader().LoadAsync(loadSetting).GetAwaiter().GetResult();

              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2);
          }

          /// <summary>
          /// Wraps a 10-unit SimpleRNNCell (with dropout and recurrent dropout) in an
          /// RNN layer, calls get_config() on it, applies it to a (32, 10, 8) input and
          /// checks that the output has shape (32, 10).
          /// </summary>
          [TestMethod]
          public void RNNForSimpleRNNCell()
          {
              var inputs = tf.random.normal((32, 10, 8));
              var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
              var rnn = tf.keras.layers.RNN(cell: cell);

              // The config itself is not inspected; the call is kept so that
              // get_config() is still exercised on an RNN layer.
              _ = rnn.get_config();

              var output = rnn.Apply(inputs);

              Assert.AreEqual((32, 10), output.shape);
          }

          /// <summary>
          /// Wraps a stack of a 4-unit and a 5-unit SimpleRNNCell in an RNN layer,
          /// applies it to a (32, 10, 8) input and checks that the output has shape
          /// (32, 5).
          /// </summary>
          [TestMethod]
          public void RNNForStackedRNNCell()
          {
              var inputs = tf.random.normal((32, 10, 8));
              var cells = new IRnnCell[]
              {
                  tf.keras.layers.SimpleRNNCell(4),
                  tf.keras.layers.SimpleRNNCell(5),
              };
              var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
              var rnn = tf.keras.layers.RNN(cell: stackedRNNCell);

              var output = rnn.Apply(inputs);

              Assert.AreEqual((32, 5), output.shape);
          }

          /// <summary>
          /// Wraps a 4-unit LSTMCell in an RNN layer, applies it to a (5, 10, 8) input,
          /// writes the output to the console and checks that it has shape (5, 4).
          /// </summary>
          [TestMethod]
          public void RNNForLSTMCell()
          {
              var inputs = tf.ones((5, 10, 8));
              var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));

              var output = rnn.Apply(inputs);

              Console.WriteLine($"output: {output}");
              Assert.AreEqual((5, 4), output.shape);
          }

          /// <summary>
          /// Wraps a 4-unit GRUCell in an RNN layer, once with default arguments and
          /// once with reset_after and use_bias both disabled, applies each to a
          /// (32, 10, 8) input and checks that the output has shape (32, 4).
          /// </summary>
          [TestMethod]
          public void GRUCell()
          {
              var inputs = tf.random.normal((32, 10, 8));

              // Default GRUCell configuration.
              var rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4));
              var output = rnn.Apply(inputs);
              Assert.AreEqual((32, 4), output.shape);

              // Same check with reset_after and use_bias switched off.
              rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4, reset_after: false, use_bias: false));
              output = rnn.Apply(inputs);
              Assert.AreEqual((32, 4), output.shape);
          }

          /// <summary>
          /// Applies a 4-unit GRU layer to a (32, 10, 8) input and checks that the
          /// output has shape (32, 4).
          /// </summary>
          [TestMethod]
          public void GRU()
          {
              var inputs = tf.ones((32, 10, 8));
              var gru = tf.keras.layers.GRU(4);

              var output = gru.Apply(inputs);

              Assert.AreEqual((32, 4), output.shape);
          }

          /// <summary>
          /// Wraps a sequence-returning LSTM(10) in a Bidirectional layer, applies it to
          /// a (32, 10, 8) input and checks that the output has shape (32, 10, 20).
          /// </summary>
          [TestMethod]
          public void Bidirectional()
          {
              var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences: true));
              var inputs = tf.random.normal((32, 10, 8));

              var outputs = bi.Apply(inputs);

              Assert.AreEqual((32, 10, 20), outputs.shape);
          }
      }
  }