using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using Tensorflow.Common.Types; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers.Rnn; using Tensorflow.Keras.Saving; using Tensorflow.NumPy; using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; namespace Tensorflow.Keras.UnitTest.Layers { [TestClass] public class Rnn { [TestMethod] public void SimpleRNNCell() { var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; var x = tf.random.normal((4, 100)); var (y, h1) = cell.Apply(inputs: x, states: h0); var h2 = h1; Assert.AreEqual((4, 64), y.shape); Assert.AreEqual((4, 64), h2[0].shape); } [TestMethod] public void StackedRNNCell() { var inputs = tf.ones((32, 10)); var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) }; var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); var (output, state) = stackedRNNCell.Apply(inputs, states); Console.WriteLine(output); Console.WriteLine(state.shape); Assert.AreEqual((32, 5), output.shape); Assert.AreEqual((32, 4), state[0].shape); } [TestMethod] public void LSTMCell() { var inputs = tf.ones((2, 100)); var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) }; var rnn = tf.keras.layers.LSTMCell(4); var (output, new_states) = rnn.Apply(inputs, states); Assert.AreEqual((2, 4), output.shape); Assert.AreEqual((2, 4), new_states[0].shape); } [TestMethod] public void TrainLSTMWithMnist() { var input = keras.Input((784)); var x = keras.layers.Reshape((28, 28)).Apply(input); //x = keras.layers.LSTM(50, return_sequences: true).Apply(x); //x = keras.layers.LSTM(100, return_sequences: true).Apply(x); //x = keras.layers.LSTM(150, return_sequences: true).Apply(x); x = keras.layers.LSTM(4, implementation: 2).Apply(x); //x = keras.layers.Dense(100).Apply(x); var output = keras.layers.Dense(10, activation: "softmax").Apply(x); var model = keras.Model(input, output); model.summary(); model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader(); var dataset = data_loader.LoadAsync(new ModelLoadSetting { TrainDir = "mnist", OneHot = false, ValidationSize = 58000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30); } [TestMethod] public void SimpleRNN() { var input = keras.Input((784)); var x = keras.layers.Reshape((28, 28)).Apply(input); x = keras.layers.SimpleRNN(10).Apply(x); var output = keras.layers.Dense(10, activation: "softmax").Apply(x); var model = keras.Model(input, output); model.summary(); model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader(); var dataset = data_loader.LoadAsync(new ModelLoadSetting { TrainDir = "mnist", OneHot = false, ValidationSize = 58000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10); } [TestMethod] public void RNNForSimpleRNNCell() { var inputs = tf.random.normal((32, 10, 8)); var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f); var rnn = tf.keras.layers.RNN(cell: cell); var output = rnn.Apply(inputs); Assert.AreEqual((32, 10), output.shape); } [TestMethod] public void RNNForStackedRNNCell() { var inputs = tf.random.normal((32, 10, 8)); var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); var rnn = tf.keras.layers.RNN(cell: stackedRNNCell); var output = rnn.Apply(inputs); Assert.AreEqual((32, 5), output.shape); } [TestMethod] public void RNNForLSTMCell() { var inputs = tf.ones((5, 10, 8)); var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4)); var output = rnn.Apply(inputs); Console.WriteLine($"output: {output}"); Assert.AreEqual((5, 4), output.shape); } } }