Browse Source

Add BasicRNNCell, LayerRNNCell, RNNCell,

Change nest.flatten to generic.
tags/v0.10
Oceania2018 6 years ago
parent
commit
ee0b935557
10 changed files with 183 additions and 18 deletions
  1. +2
    -1
      docs/source/ConvolutionNeuralNetwork.md
  2. +29
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +40
    -2
      src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
  4. +33
    -0
      src/TensorFlowNET.Core/Operations/LayerRNNCell.cs
  5. +63
    -0
      src/TensorFlowNET.Core/Operations/RNNCell.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  7. +1
    -3
      src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
  8. +5
    -5
      src/TensorFlowNET.Core/Util/nest.py.cs
  9. +5
    -5
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
  10. +2
    -2
      test/TensorFlowNET.UnitTest/nest_test/NestTest.cs

+ 2
- 1
docs/source/ConvolutionNeuralNetwork.md View File

@@ -346,4 +346,5 @@ Get started with the implementation:
}
```

![cnn-reuslt](../assets/cnn-result.png)
![](../assets/cnn-result.png)


+ 29
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -16,9 +16,11 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
using Tensorflow.Util;
using static Tensorflow.Python;

namespace Tensorflow
@@ -68,6 +70,33 @@ namespace Tensorflow
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
}

/// <summary>
/// Creates a recurrent neural network specified by RNNCell `cell`.
/// </summary>
/// <param name="cell">An instance of RNNCell.</param>
/// <param name="inputs">The RNN inputs.</param>
/// <param name="dtype"></param>
/// <param name="swap_memory"></param>
/// <param name="time_major"></param>
/// <returns>A pair (outputs, state)</returns>
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid,
bool swap_memory = false, bool time_major = false)
{
with(variable_scope("rnn"), scope =>
{
VariableScope varscope = scope;
var flat_input = nest.flatten(inputs);

if (!time_major)
{
flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList();
//flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList();
}
});

throw new NotImplementedException("");
}

public static (Tensor, Tensor) moments(Tensor x,
int[] axes,
string name = null,


+ 40
- 2
src/TensorFlowNET.Core/Operations/BasicRNNCell.cs View File

@@ -1,10 +1,48 @@
using System;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations.Activation;

namespace Tensorflow
{
public class BasicRNNCell
public class BasicRNNCell : LayerRNNCell
{
int _num_units;
Func<Tensor, string, Tensor> _activation;

public BasicRNNCell(int num_units,
Func<Tensor, string, Tensor> activation = null,
bool? reuse = null,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse,
name: name,
dtype: dtype)
{
// Inputs must be 2-dimensional.
input_spec = new InputSpec(ndim: 2);

_num_units = num_units;
if (activation == null)
_activation = math_ops.tanh;
else
_activation = activation;
}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Operations/LayerRNNCell.cs View File

@@ -0,0 +1,33 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class LayerRNNCell : RNNCell
{
public LayerRNNCell(bool? _reuse = null,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse,
name: name,
dtype: dtype)
{
}
}
}

+ 63
- 0
src/TensorFlowNET.Core/Operations/RNNCell.cs View File

@@ -0,0 +1,63 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Abstract object representing an RNN cell.
///
/// Every `RNNCell` must have the properties below and implement `call` with
/// the signature `(output, next_state) = call(input, state)`. The optional
/// third input argument, `scope`, is allowed for backwards compatibility
/// purposes; but should be left off for new subclasses.
///
/// This definition of cell differs from the definition used in the literature.
/// In the literature, 'cell' refers to an object with a single scalar output.
/// This definition refers to a horizontal array of such units.
///
/// An RNN cell, in the most abstract setting, is anything that has
/// a state and performs some operation that takes a matrix of inputs.
/// This operation results in an output matrix with `self.output_size` columns.
/// If `self.state_size` is an integer, this operation also results in a new
/// state matrix with `self.state_size` columns. If `self.state_size` is a
/// (possibly nested tuple of) TensorShape object(s), then it should return a
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`.
/// </summary>
public abstract class RNNCell : Layers.Layer
{
/// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight
/// difference between TF and Keras RNN cell.
/// </summary>
protected bool _is_tf_rnn_cell = false;

public RNNCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool? _reuse = null) : base(trainable: trainable,
name: name,
dtype: dtype,
_reuse: _reuse)
{
_is_tf_rnn_cell = true;
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -551,6 +551,9 @@ namespace Tensorflow
});
}

public static Tensor tanh(Tensor x, string name = null)
=> gen_math_ops.tanh(x, name);

public static Tensor truediv(Tensor x, Tensor y, string name = null)
=> _truediv_python3(x, y, name);



+ 1
- 3
src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs View File

@@ -7,8 +7,6 @@ namespace Tensorflow.Operations
public class rnn_cell_impl
{
public BasicRNNCell BasicRNNCell(int num_units)
{
throw new NotImplementedException();
}
=> new BasicRNNCell(num_units);
}
}

+ 5
- 5
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -214,14 +214,14 @@ namespace Tensorflow.Util
//# See the swig file (util.i) for documentation.
//flatten = _pywrap_tensorflow.Flatten
public static List<object> flatten(object structure)
public static List<T> flatten<T>(T structure)
{
var list = new List<object>();
var list = new List<T>();
_flatten_recursive(structure, list);
return list;
}
private static void _flatten_recursive(object obj, List<object> list)
private static void _flatten_recursive<T>(T obj, List<T> list)
{
if (obj is string)
{
@@ -232,7 +232,7 @@ namespace Tensorflow.Util
{
var dict = obj as IDictionary;
foreach (var key in _sorted(dict))
_flatten_recursive(dict[key], list);
_flatten_recursive((T)dict[key], list);
return;
}
if (obj is NDArray)
@@ -244,7 +244,7 @@ namespace Tensorflow.Util
{
var structure = obj as IEnumerable;
foreach (var child in structure)
_flatten_recursive(child, list);
_flatten_recursive((T)child, list);
return;
}
list.Add(obj);


+ 5
- 5
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs View File

@@ -25,14 +25,12 @@ using static Tensorflow.Python;
namespace TensorFlowNET.Examples.ImageProcess
{
/// <summary>
/// Convolutional Neural Network classifier for Hand Written Digits
/// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end.
/// Use Stochastic Gradient Descent (SGD) optimizer.
/// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1
/// Recurrent Neural Network for handwritten digits MNIST.
/// https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2
/// </summary>
public class DigitRecognitionRNN : IExample
{
public bool Enabled { get; set; } = false;
public bool Enabled { get; set; } = true;
public bool IsImportingGraph { get; set; } = false;

public string Name => "MNIST RNN";
@@ -84,6 +82,7 @@ namespace TensorFlowNET.Examples.ImageProcess
var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs });
var y = tf.placeholder(tf.int32, new[] { -1 });
var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons);
var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32);

return graph;
}
@@ -154,6 +153,7 @@ namespace TensorFlowNET.Examples.ImageProcess
print("Size of:");
print($"- Training-set:\t\t{len(mnist.train.data)}");
print($"- Validation-set:\t{len(mnist.validation.data)}");
print($"- Test-set:\t\t{len(mnist.test.data)}");
}

public Graph ImportGraph() => throw new NotImplementedException();


+ 2
- 2
test/TensorFlowNET.UnitTest/nest_test/NestTest.cs View File

@@ -78,8 +78,8 @@ namespace TensorFlowNET.UnitTest.nest_test
self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0);
self.assertEqual(new List<object> { 5 }, nest.flatten(5));
flat = nest.flatten(np.array(new[] { 5 }));
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat);
var flat1 = nest.flatten(np.array(new[] { 5 }));
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1);
self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
self.assertEqual(np.array(new[] { 5 }),


Loading…
Cancel
Save