Browse Source

tf.flayers: Added flatten

tags/v0.12
Eli Belash 6 years ago
parent
commit
f7ef39cac8
2 changed files with 74 additions and 0 deletions
  1. +34
    -0
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +40
    -0
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs

+ 34
- 0
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -163,6 +164,39 @@ namespace Tensorflow

return layer.apply(inputs);
}

/// <summary>
/// Flattens an input tensor while preserving the batch axis (axis 0).
/// </summary>
/// <param name="inputs">Tensor input.</param>
/// <param name="name">The name of the layer.</param>
/// <param name="data_format">
/// A string, one of `channels_last` (default) or `channels_first`. <br></br>
/// The ordering of the dimensions in the inputs. <br></br>
/// `channels_last` corresponds to inputs with shape <br></br>
/// `(batch, height, width, channels)` while `channels_first` corresponds to <br></br>
/// inputs with shape `(batch, channels, height, width)`.
/// </param>
/// <returns></returns>
public Tensor flatten(Tensor inputs,
string name = null,
string data_format = "channels_last")
{
if (inputs.shape.Length == 0)
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");

var premutation = new List<int>() {0};
if (data_format == "channels_first" && inputs.NDims > 1)
{
premutation.AddRange(Binding.range(2, inputs.NDims));
premutation.Add(1);
inputs = array_ops.transpose(inputs, premutation.ToArray());
}

var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
ret.set_shape(new int[] {inputs.shape[0], -1});
return ret;
}
}
}
}

+ 40
- 0
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -0,0 +1,40 @@
using System;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.layers_test
{
[TestClass]
public class flatten
{
[TestMethod]
public void Case1()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}

[TestMethod]
public void Case2()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1);
}

[TestMethod]
public void Case3()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
}
}
}

Loading…
Cancel
Save