You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Rnn.Test.cs 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. using Tensorflow.Common.Types;
  8. using Tensorflow.Keras.Engine;
  9. using Tensorflow.Keras.Layers.Rnn;
  10. using Tensorflow.Keras.Saving;
  11. using Tensorflow.NumPy;
  12. using Tensorflow.Train;
  13. using static Tensorflow.Binding;
  14. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Layers
  {
      /// <summary>
      /// Tests for the Keras recurrent layers. Covered: the single-step cells
      /// (<c>SimpleRNNCell</c>, <c>StackedRNNCells</c>, <c>LSTMCell</c>), the <c>RNN</c> wrapper
      /// applied to 3-D input batches, and two end-to-end MNIST training runs.
      /// </summary>
      [TestClass]
      public class Rnn
      {
          /// <summary>Training directory handed to <see cref="MnistModelLoader"/>.</summary>
          private const string MnistTrainDir = "mnist";

          /// <summary>
          /// Value used for <c>ModelLoadSetting.ValidationSize</c>.
          /// NOTE(review): presumably chosen so that only a small part of the MNIST training set
          /// remains for training, keeping the training tests short — confirm.
          /// </summary>
          private const int MnistValidationSize = 58000;

          /// <summary>
          /// One step of a 64-unit <c>SimpleRNNCell</c> (input and recurrent dropout 0.5) on a
          /// (4, 100) batch. Both the output and the new hidden state must be (4, 64).
          /// </summary>
          [TestMethod]
          public void SimpleRNNCell()
          {
              var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
              var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
              var x = tf.random.normal((4, 100));

              var (y, h1) = cell.Apply(inputs: x, states: h0);

              Assert.AreEqual((4, 64), y.shape);
              Assert.AreEqual((4, 64), h1[0].shape);
          }

          /// <summary>
          /// One step of two stacked <c>SimpleRNNCell</c>s (4 and 5 units) on a (32, 10) batch,
          /// starting from zero states sized for each cell.
          /// </summary>
          [TestMethod]
          public void StackedRNNCell()
          {
              var inputs = tf.ones((32, 10));
              var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
              var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
              var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);

              var (output, state) = stackedRNNCell.Apply(inputs, states);

              // The output width matches the last cell (5 units); state[0] matches the first cell (4 units).
              Assert.AreEqual((32, 5), output.shape);
              Assert.AreEqual((32, 4), state[0].shape);
          }

          /// <summary>
          /// One step of a 4-unit <c>LSTMCell</c> on a (2, 100) batch with two zero initial states.
          /// The output and the first returned state must both be (2, 4).
          /// </summary>
          [TestMethod]
          public void LSTMCell()
          {
              var inputs = tf.ones((2, 100));
              var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
              var rnn = tf.keras.layers.LSTMCell(4);

              var (output, new_states) = rnn.Apply(inputs, states);

              Assert.AreEqual((2, 4), output.shape);
              Assert.AreEqual((2, 4), new_states[0].shape);
          }

          /// <summary>
          /// End-to-end training run. Each flattened 784-value input is reshaped to (28, 28) and fed
          /// through a 4-unit <c>LSTM</c> and a 10-way softmax <c>Dense</c> layer. The model is
          /// then fit on MNIST for 30 epochs.
          /// NOTE(review): depends on MNIST data under <see cref="MnistTrainDir"/> (the loader may
          /// fetch it over the network) and is slow — confirm whether it belongs in the unit-test run.
          /// </summary>
          [TestMethod]
          public async Task TrainLSTMWithMnist()
          {
              var input = keras.Input(new Shape(784));
              var x = keras.layers.Reshape((28, 28)).Apply(input);
              x = keras.layers.LSTM(4, implementation: 2).Apply(x);
              var output = keras.layers.Dense(10, activation: "softmax").Apply(x);
              var model = keras.Model(input, output);
              model.summary();
              // Labels are loaded with OneHot = false, so the sparse variant of the loss is required.
              model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

              // Awaiting (instead of .Result) surfaces loader failures directly rather than wrapped in an AggregateException.
              var dataset = await new MnistModelLoader().LoadAsync(CreateMnistLoadSetting());

              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30);
          }

          /// <summary>
          /// End-to-end training run. Each flattened 784-value input is reshaped to (28, 28) and fed
          /// through a 10-unit <c>SimpleRNN</c> and a 10-way softmax <c>Dense</c> layer. The model is
          /// then fit on MNIST for 10 epochs.
          /// NOTE(review): depends on MNIST data under <see cref="MnistTrainDir"/> and is slow — confirm.
          /// </summary>
          [TestMethod]
          public async Task SimpleRNN()
          {
              var input = keras.Input(new Shape(784));
              var x = keras.layers.Reshape((28, 28)).Apply(input);
              x = keras.layers.SimpleRNN(10).Apply(x);
              var output = keras.layers.Dense(10, activation: "softmax").Apply(x);
              var model = keras.Model(input, output);
              model.summary();
              // Labels are loaded with OneHot = false. CategoricalCrossentropy would expect one-hot
              // targets shaped like the softmax output, so the sparse loss is the correct pairing here.
              model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

              var dataset = await new MnistModelLoader().LoadAsync(CreateMnistLoadSetting());

              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10);
          }

          /// <summary>
          /// An <c>RNN</c> layer wrapping a 10-unit <c>SimpleRNNCell</c> (input and recurrent
          /// dropout 0.5), applied to a (32, 10, 8) batch. The output must be (32, 10).
          /// </summary>
          [TestMethod]
          public void RNNForSimpleRNNCell()
          {
              var inputs = tf.random.normal((32, 10, 8));
              var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
              var rnn = tf.keras.layers.RNN(cell: cell);

              var output = rnn.Apply(inputs);

              Assert.AreEqual((32, 10), output.shape);
          }

          /// <summary>
          /// An <c>RNN</c> layer wrapping stacked <c>SimpleRNNCell</c>s (4 and 5 units), applied to a
          /// (32, 10, 8) batch. The output width must match the last cell: (32, 5).
          /// </summary>
          [TestMethod]
          public void RNNForStackedRNNCell()
          {
              var inputs = tf.random.normal((32, 10, 8));
              var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
              var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
              var rnn = tf.keras.layers.RNN(cell: stackedRNNCell);

              var output = rnn.Apply(inputs);

              Assert.AreEqual((32, 5), output.shape);
          }

          /// <summary>
          /// An <c>RNN</c> layer wrapping a 4-unit <c>LSTMCell</c>, applied to a (5, 10, 8) batch of
          /// ones. The output must be (5, 4).
          /// </summary>
          [TestMethod]
          public void RNNForLSTMCell()
          {
              var inputs = tf.ones((5, 10, 8));
              var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));

              var output = rnn.Apply(inputs);

              Assert.AreEqual((5, 4), output.shape);
          }

          /// <summary>
          /// Loader settings shared by the MNIST training tests. Data is read from
          /// <see cref="MnistTrainDir"/>, labels are not one-hot encoded, and
          /// <c>ValidationSize</c> is <see cref="MnistValidationSize"/>.
          /// </summary>
          private static ModelLoadSetting CreateMnistLoadSetting() => new ModelLoadSetting
          {
              TrainDir = MnistTrainDir,
              OneHot = false,
              ValidationSize = MnistValidationSize,
          };
      }
  }