diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 53ddcf13..089dd8a5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; using Tensorflow.Keras.Layers; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -163,6 +164,39 @@ namespace Tensorflow return layer.apply(inputs); } + + /// + /// Flattens an input tensor while preserving the batch axis (axis 0). + /// + /// Tensor input. + /// The name of the layer. + /// + /// A string, one of `channels_last` (default) or `channels_first`.

+ /// The ordering of the dimensions in the inputs.

+ /// `channels_last` corresponds to inputs with shape

+ /// `(batch, height, width, channels)` while `channels_first` corresponds to

+ /// inputs with shape `(batch, channels, height, width)`. + /// + /// + public Tensor flatten(Tensor inputs, + string name = null, + string data_format = "channels_last") + { + if (inputs.shape.Length == 0) + throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); + + var premutation = new List() {0}; + if (data_format == "channels_first" && inputs.NDims > 1) + { + premutation.AddRange(Binding.range(2, inputs.NDims)); + premutation.Add(1); + inputs = array_ops.transpose(inputs, premutation.ToArray()); + } + + var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1}); + ret.set_shape(new int[] {inputs.shape[0], -1}); + return ret; + } } } } diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs new file mode 100644 index 00000000..d533f128 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs @@ -0,0 +1,40 @@ +using System; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.layers_test +{ + [TestClass] + public class flatten + { + [TestMethod] + public void Case1() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } + + [TestMethod] + public void Case2() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); + sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1); + } + + [TestMethod] + public void Case3() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); + new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw(); + } + } +} \ No newline at end of file