From ee0b9355579fb68daebfebe00a809328a101e0a1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 17 Jul 2019 19:43:41 -0500 Subject: [PATCH] Add BasicRNNCell, LayerRNNCell, RNNCell, Change nest.flatten to generic. --- docs/source/ConvolutionNeuralNetwork.md | 3 +- src/TensorFlowNET.Core/APIs/tf.nn.cs | 29 +++++++++ .../Operations/BasicRNNCell.cs | 42 ++++++++++++- .../Operations/LayerRNNCell.cs | 33 ++++++++++ src/TensorFlowNET.Core/Operations/RNNCell.cs | 63 +++++++++++++++++++ src/TensorFlowNET.Core/Operations/math_ops.cs | 3 + .../Operations/rnn_cell_impl.cs | 4 +- src/TensorFlowNET.Core/Util/nest.py.cs | 10 +-- .../ImageProcessing/DigitRecognitionRNN.cs | 10 +-- .../nest_test/NestTest.cs | 4 +- 10 files changed, 183 insertions(+), 18 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/LayerRNNCell.cs create mode 100644 src/TensorFlowNET.Core/Operations/RNNCell.cs diff --git a/docs/source/ConvolutionNeuralNetwork.md b/docs/source/ConvolutionNeuralNetwork.md index 55dd4265..6b47c9d8 100644 --- a/docs/source/ConvolutionNeuralNetwork.md +++ b/docs/source/ConvolutionNeuralNetwork.md @@ -346,4 +346,5 @@ Get started with the implementation: } ``` -![cnn-reuslt](../assets/cnn-result.png) \ No newline at end of file +![](../assets/cnn-result.png) + diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 56f2be59..4c7ba775 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -16,9 +16,11 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Operations; using Tensorflow.Operations.Activation; +using Tensorflow.Util; using static Tensorflow.Python; namespace Tensorflow @@ -68,6 +70,33 @@ namespace Tensorflow return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); } + /// + /// Creates a recurrent neural network specified by RNNCell `cell`. + /// + /// An instance of RNNCell. + /// The RNN inputs. + /// + /// + /// + /// A pair (outputs, state) + public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid, + bool swap_memory = false, bool time_major = false) + { + with(variable_scope("rnn"), scope => + { + VariableScope varscope = scope; + var flat_input = nest.flatten(inputs); + + if (!time_major) + { + flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); + //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); + } + }); + + throw new NotImplementedException(""); + } + public static (Tensor, Tensor) moments(Tensor x, int[] axes, string name = null, diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 98ca7e22..1bed4773 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -1,10 +1,48 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Operations.Activation; namespace Tensorflow { - public class BasicRNNCell + public class BasicRNNCell : LayerRNNCell { + int _num_units; + Func _activation; + + public BasicRNNCell(int num_units, + Func activation = null, + bool? reuse = null, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, + name: name, + dtype: dtype) + { + // Inputs must be 2-dimensional. + input_spec = new InputSpec(ndim: 2); + + _num_units = num_units; + if (activation == null) + _activation = math_ops.tanh; + else + _activation = activation; + } } } diff --git a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs new file mode 100644 index 00000000..0f9aa254 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs @@ -0,0 +1,33 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class LayerRNNCell : RNNCell + { + public LayerRNNCell(bool? _reuse = null, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, + name: name, + dtype: dtype) + { + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs new file mode 100644 index 00000000..cbfe7db8 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs @@ -0,0 +1,63 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// Abstract object representing an RNN cell. + /// + /// Every `RNNCell` must have the properties below and implement `call` with + /// the signature `(output, next_state) = call(input, state)`. The optional + /// third input argument, `scope`, is allowed for backwards compatibility + /// purposes; but should be left off for new subclasses. + /// + /// This definition of cell differs from the definition used in the literature. + /// In the literature, 'cell' refers to an object with a single scalar output. + /// This definition refers to a horizontal array of such units. + /// + /// An RNN cell, in the most abstract setting, is anything that has + /// a state and performs some operation that takes a matrix of inputs. + /// This operation results in an output matrix with `self.output_size` columns. + /// If `self.state_size` is an integer, this operation also results in a new + /// state matrix with `self.state_size` columns. If `self.state_size` is a + /// (possibly nested tuple of) TensorShape object(s), then it should return a + /// matching structure of Tensors having shape `[batch_size].concatenate(s)` + /// for each `s` in `self.batch_size`. + /// + public abstract class RNNCell : Layers.Layer + { + /// + /// Attribute that indicates whether the cell is a TF RNN cell, due the slight + /// difference between TF and Keras RNN cell. + /// + protected bool _is_tf_rnn_cell = false; + + public RNNCell(bool trainable = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool? _reuse = null) : base(trainable: trainable, + name: name, + dtype: dtype, + _reuse: _reuse) + { + _is_tf_rnn_cell = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 9f310bdb..375da903 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -551,6 +551,9 @@ namespace Tensorflow }); } + public static Tensor tanh(Tensor x, string name = null) + => gen_math_ops.tanh(x, name); + public static Tensor truediv(Tensor x, Tensor y, string name = null) => _truediv_python3(x, y, name); diff --git a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs index df137766..d0625cc9 100644 --- a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs +++ b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs @@ -7,8 +7,6 @@ namespace Tensorflow.Operations public class rnn_cell_impl { public BasicRNNCell BasicRNNCell(int num_units) - { - throw new NotImplementedException(); - } + => new BasicRNNCell(num_units); } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 060539cc..47e4e4aa 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -214,14 +214,14 @@ namespace Tensorflow.Util //# See the swig file (util.i) for documentation. //flatten = _pywrap_tensorflow.Flatten - public static List flatten(object structure) + public static List flatten(T structure) { - var list = new List(); + var list = new List(); _flatten_recursive(structure, list); return list; } - private static void _flatten_recursive(object obj, List list) + private static void _flatten_recursive(T obj, List list) { if (obj is string) { @@ -232,7 +232,7 @@ namespace Tensorflow.Util { var dict = obj as IDictionary; foreach (var key in _sorted(dict)) - _flatten_recursive(dict[key], list); + _flatten_recursive((T)dict[key], list); return; } if (obj is NDArray) @@ -244,7 +244,7 @@ namespace Tensorflow.Util { var structure = obj as IEnumerable; foreach (var child in structure) - _flatten_recursive(child, list); + _flatten_recursive((T)child, list); return; } list.Add(obj); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index a0e72309..7449ba42 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -25,14 +25,12 @@ using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess { /// - /// Convolutional Neural Network classifier for Hand Written Digits - /// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end. - /// Use Stochastic Gradient Descent (SGD) optimizer. - /// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1 + /// Recurrent Neural Network for handwritten digits MNIST. + /// https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2 /// public class DigitRecognitionRNN : IExample { - public bool Enabled { get; set; } = false; + public bool Enabled { get; set; } = true; public bool IsImportingGraph { get; set; } = false; public string Name => "MNIST RNN"; @@ -84,6 +82,7 @@ namespace TensorFlowNET.Examples.ImageProcess var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs }); var y = tf.placeholder(tf.int32, new[] { -1 }); var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons); + var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32); return graph; } @@ -154,6 +153,7 @@ namespace TensorFlowNET.Examples.ImageProcess print("Size of:"); print($"- Training-set:\t\t{len(mnist.train.data)}"); print($"- Validation-set:\t{len(mnist.validation.data)}"); + print($"- Test-set:\t\t{len(mnist.test.data)}"); } public Graph ImportGraph() => throw new NotImplementedException(); diff --git a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs index 16026ac0..3fd90002 100644 --- a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs +++ b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs @@ -78,8 +78,8 @@ namespace TensorFlowNET.UnitTest.nest_test self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0); self.assertEqual(new List { 5 }, nest.flatten(5)); - flat = nest.flatten(np.array(new[] { 5 })); - self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat); + var flat1 = nest.flatten(np.array(new[] { 5 })); + self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1); self.assertEqual("a", nest.pack_sequence_as(5, new List { "a" })); self.assertEqual(np.array(new[] { 5 }),