@@ -346,4 +346,5 @@ Get started with the implementation: | |||
} | |||
``` | |||
 | |||
 | |||
@@ -16,9 +16,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Operations.Activation; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Python; | |||
namespace Tensorflow | |||
@@ -68,6 +70,33 @@ namespace Tensorflow | |||
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); | |||
} | |||
/// <summary> | |||
/// Creates a recurrent neural network specified by RNNCell `cell`. | |||
/// </summary> | |||
/// <param name="cell">An instance of RNNCell.</param> | |||
/// <param name="inputs">The RNN inputs.</param> | |||
/// <param name="dtype"></param> | |||
/// <param name="swap_memory"></param> | |||
/// <param name="time_major"></param> | |||
/// <returns>A pair (outputs, state)</returns> | |||
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid, | |||
bool swap_memory = false, bool time_major = false) | |||
{ | |||
with(variable_scope("rnn"), scope => | |||
{ | |||
VariableScope varscope = scope; | |||
var flat_input = nest.flatten(inputs); | |||
if (!time_major) | |||
{ | |||
flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); | |||
//flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); | |||
} | |||
}); | |||
throw new NotImplementedException(""); | |||
} | |||
public static (Tensor, Tensor) moments(Tensor x, | |||
int[] axes, | |||
string name = null, | |||
@@ -1,10 +1,48 @@ | |||
using System; | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Operations.Activation; | |||
namespace Tensorflow | |||
{ | |||
public class BasicRNNCell | |||
public class BasicRNNCell : LayerRNNCell | |||
{ | |||
int _num_units; | |||
Func<Tensor, string, Tensor> _activation; | |||
public BasicRNNCell(int num_units, | |||
Func<Tensor, string, Tensor> activation = null, | |||
bool? reuse = null, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, | |||
name: name, | |||
dtype: dtype) | |||
{ | |||
// Inputs must be 2-dimensional. | |||
input_spec = new InputSpec(ndim: 2); | |||
_num_units = num_units; | |||
if (activation == null) | |||
_activation = math_ops.tanh; | |||
else | |||
_activation = activation; | |||
} | |||
} | |||
} |
@@ -0,0 +1,33 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class LayerRNNCell : RNNCell | |||
{ | |||
public LayerRNNCell(bool? _reuse = null, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, | |||
name: name, | |||
dtype: dtype) | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,63 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Abstract object representing an RNN cell. | |||
/// | |||
/// Every `RNNCell` must have the properties below and implement `call` with | |||
/// the signature `(output, next_state) = call(input, state)`. The optional | |||
/// third input argument, `scope`, is allowed for backwards compatibility | |||
/// purposes; but should be left off for new subclasses. | |||
/// | |||
/// This definition of cell differs from the definition used in the literature. | |||
/// In the literature, 'cell' refers to an object with a single scalar output. | |||
/// This definition refers to a horizontal array of such units. | |||
/// | |||
/// An RNN cell, in the most abstract setting, is anything that has | |||
/// a state and performs some operation that takes a matrix of inputs. | |||
/// This operation results in an output matrix with `self.output_size` columns. | |||
/// If `self.state_size` is an integer, this operation also results in a new | |||
/// state matrix with `self.state_size` columns. If `self.state_size` is a | |||
/// (possibly nested tuple of) TensorShape object(s), then it should return a | |||
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
/// for each `s` in `self.batch_size`. | |||
/// </summary> | |||
public abstract class RNNCell : Layers.Layer | |||
{ | |||
/// <summary> | |||
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
/// difference between TF and Keras RNN cell. | |||
/// </summary> | |||
protected bool _is_tf_rnn_cell = false; | |||
public RNNCell(bool trainable = true, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
bool? _reuse = null) : base(trainable: trainable, | |||
name: name, | |||
dtype: dtype, | |||
_reuse: _reuse) | |||
{ | |||
_is_tf_rnn_cell = true; | |||
} | |||
} | |||
} |
@@ -551,6 +551,9 @@ namespace Tensorflow | |||
}); | |||
} | |||
public static Tensor tanh(Tensor x, string name = null) | |||
=> gen_math_ops.tanh(x, name); | |||
public static Tensor truediv(Tensor x, Tensor y, string name = null) | |||
=> _truediv_python3(x, y, name); | |||
@@ -7,8 +7,6 @@ namespace Tensorflow.Operations | |||
public class rnn_cell_impl | |||
{ | |||
public BasicRNNCell BasicRNNCell(int num_units) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
=> new BasicRNNCell(num_units); | |||
} | |||
} |
@@ -214,14 +214,14 @@ namespace Tensorflow.Util | |||
//# See the swig file (util.i) for documentation. | |||
//flatten = _pywrap_tensorflow.Flatten | |||
public static List<object> flatten(object structure) | |||
public static List<T> flatten<T>(T structure) | |||
{ | |||
var list = new List<object>(); | |||
var list = new List<T>(); | |||
_flatten_recursive(structure, list); | |||
return list; | |||
} | |||
private static void _flatten_recursive(object obj, List<object> list) | |||
private static void _flatten_recursive<T>(T obj, List<T> list) | |||
{ | |||
if (obj is string) | |||
{ | |||
@@ -232,7 +232,7 @@ namespace Tensorflow.Util | |||
{ | |||
var dict = obj as IDictionary; | |||
foreach (var key in _sorted(dict)) | |||
_flatten_recursive(dict[key], list); | |||
_flatten_recursive((T)dict[key], list); | |||
return; | |||
} | |||
if (obj is NDArray) | |||
@@ -244,7 +244,7 @@ namespace Tensorflow.Util | |||
{ | |||
var structure = obj as IEnumerable; | |||
foreach (var child in structure) | |||
_flatten_recursive(child, list); | |||
_flatten_recursive((T)child, list); | |||
return; | |||
} | |||
list.Add(obj); | |||
@@ -25,14 +25,12 @@ using static Tensorflow.Python; | |||
namespace TensorFlowNET.Examples.ImageProcess | |||
{ | |||
/// <summary> | |||
/// Convolutional Neural Network classifier for Hand Written Digits | |||
/// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end. | |||
/// Use Stochastic Gradient Descent (SGD) optimizer. | |||
/// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1 | |||
/// Recurrent Neural Network for handwritten digits MNIST. | |||
/// https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2 | |||
/// </summary> | |||
public class DigitRecognitionRNN : IExample | |||
{ | |||
public bool Enabled { get; set; } = false; | |||
public bool Enabled { get; set; } = true; | |||
public bool IsImportingGraph { get; set; } = false; | |||
public string Name => "MNIST RNN"; | |||
@@ -84,6 +82,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs }); | |||
var y = tf.placeholder(tf.int32, new[] { -1 }); | |||
var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons); | |||
var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32); | |||
return graph; | |||
} | |||
@@ -154,6 +153,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
print("Size of:"); | |||
print($"- Training-set:\t\t{len(mnist.train.data)}"); | |||
print($"- Validation-set:\t{len(mnist.validation.data)}"); | |||
print($"- Test-set:\t\t{len(mnist.test.data)}"); | |||
} | |||
public Graph ImportGraph() => throw new NotImplementedException(); | |||
@@ -78,8 +78,8 @@ namespace TensorFlowNET.UnitTest.nest_test | |||
self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0); | |||
self.assertEqual(new List<object> { 5 }, nest.flatten(5)); | |||
flat = nest.flatten(np.array(new[] { 5 })); | |||
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat); | |||
var flat1 = nest.flatten(np.array(new[] { 5 })); | |||
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1); | |||
self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" })); | |||
self.assertEqual(np.array(new[] { 5 }), | |||