diff --git a/docs/source/ConvolutionNeuralNetwork.md b/docs/source/ConvolutionNeuralNetwork.md
index 55dd4265..6b47c9d8 100644
--- a/docs/source/ConvolutionNeuralNetwork.md
+++ b/docs/source/ConvolutionNeuralNetwork.md
@@ -346,4 +346,5 @@ Get started with the implementation:
}
```
-
\ No newline at end of file
+
+
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 56f2be59..4c7ba775 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -16,9 +16,11 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
+using Tensorflow.Util;
using static Tensorflow.Python;
namespace Tensorflow
@@ -68,6 +70,33 @@ namespace Tensorflow
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
}
+ ///
+ /// Creates a recurrent neural network specified by RNNCell `cell`.
+ ///
+ /// An instance of RNNCell.
+ /// The RNN inputs.
+ ///
+ ///
+ ///
+ /// A pair (outputs, state)
+ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid,
+ bool swap_memory = false, bool time_major = false)
+ {
+ with(variable_scope("rnn"), scope =>
+ {
+ VariableScope varscope = scope;
+ var flat_input = nest.flatten(inputs);
+
+ if (!time_major)
+ {
+ flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
+ //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
+ }
+ });
+
+ throw new NotImplementedException("");
+ }
+
public static (Tensor, Tensor) moments(Tensor x,
int[] axes,
string name = null,
diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
index 98ca7e22..1bed4773 100644
--- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
@@ -1,10 +1,48 @@
-using System;
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Operations.Activation;
namespace Tensorflow
{
- public class BasicRNNCell
+ public class BasicRNNCell : LayerRNNCell
{
+ int _num_units;
+ Func _activation;
+
+ public BasicRNNCell(int num_units,
+ Func activation = null,
+ bool? reuse = null,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse,
+ name: name,
+ dtype: dtype)
+ {
+ // Inputs must be 2-dimensional.
+ input_spec = new InputSpec(ndim: 2);
+
+ _num_units = num_units;
+ if (activation == null)
+ _activation = math_ops.tanh;
+ else
+ _activation = activation;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs
new file mode 100644
index 00000000..0f9aa254
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs
@@ -0,0 +1,33 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class LayerRNNCell : RNNCell
+ {
+ public LayerRNNCell(bool? _reuse = null,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse,
+ name: name,
+ dtype: dtype)
+ {
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs
new file mode 100644
index 00000000..cbfe7db8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs
@@ -0,0 +1,63 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ ///
+ /// Abstract object representing an RNN cell.
+ ///
+ /// Every `RNNCell` must have the properties below and implement `call` with
+ /// the signature `(output, next_state) = call(input, state)`. The optional
+ /// third input argument, `scope`, is allowed for backwards compatibility
+ /// purposes; but should be left off for new subclasses.
+ ///
+ /// This definition of cell differs from the definition used in the literature.
+ /// In the literature, 'cell' refers to an object with a single scalar output.
+ /// This definition refers to a horizontal array of such units.
+ ///
+ /// An RNN cell, in the most abstract setting, is anything that has
+ /// a state and performs some operation that takes a matrix of inputs.
+ /// This operation results in an output matrix with `self.output_size` columns.
+ /// If `self.state_size` is an integer, this operation also results in a new
+ /// state matrix with `self.state_size` columns. If `self.state_size` is a
+ /// (possibly nested tuple of) TensorShape object(s), then it should return a
+ /// matching structure of Tensors having shape `[batch_size].concatenate(s)`
+ /// for each `s` in `self.batch_size`.
+ ///
+ public abstract class RNNCell : Layers.Layer
+ {
+ ///
+ /// Attribute that indicates whether the cell is a TF RNN cell, due the slight
+ /// difference between TF and Keras RNN cell.
+ ///
+ protected bool _is_tf_rnn_cell = false;
+
+ public RNNCell(bool trainable = true,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ bool? _reuse = null) : base(trainable: trainable,
+ name: name,
+ dtype: dtype,
+ _reuse: _reuse)
+ {
+ _is_tf_rnn_cell = true;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 9f310bdb..375da903 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -551,6 +551,9 @@ namespace Tensorflow
});
}
+ public static Tensor tanh(Tensor x, string name = null)
+ => gen_math_ops.tanh(x, name);
+
public static Tensor truediv(Tensor x, Tensor y, string name = null)
=> _truediv_python3(x, y, name);
diff --git a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
index df137766..d0625cc9 100644
--- a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
@@ -7,8 +7,6 @@ namespace Tensorflow.Operations
public class rnn_cell_impl
{
public BasicRNNCell BasicRNNCell(int num_units)
- {
- throw new NotImplementedException();
- }
+ => new BasicRNNCell(num_units);
}
}
diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
index 060539cc..47e4e4aa 100644
--- a/src/TensorFlowNET.Core/Util/nest.py.cs
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -214,14 +214,14 @@ namespace Tensorflow.Util
//# See the swig file (util.i) for documentation.
//flatten = _pywrap_tensorflow.Flatten
- public static List