diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index 53ddcf13..089dd8a5 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -163,6 +164,39 @@ namespace Tensorflow
return layer.apply(inputs);
}
+
+ ///
+ /// Flattens an input tensor while preserving the batch axis (axis 0).
+ ///
+ /// Tensor input.
+ /// The name of the layer.
+ ///
+ /// A string, one of `channels_last` (default) or `channels_first`.
+ /// The ordering of the dimensions in the inputs.
+ /// `channels_last` corresponds to inputs with shape
+ /// `(batch, height, width, channels)` while `channels_first` corresponds to
+ /// inputs with shape `(batch, channels, height, width)`.
+ ///
+ ///
+ public Tensor flatten(Tensor inputs,
+ string name = null,
+ string data_format = "channels_last")
+ {
+ if (inputs.shape.Length == 0)
+ throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");
+
+ var premutation = new List() {0};
+ if (data_format == "channels_first" && inputs.NDims > 1)
+ {
+ premutation.AddRange(Binding.range(2, inputs.NDims));
+ premutation.Add(1);
+ inputs = array_ops.transpose(inputs, premutation.ToArray());
+ }
+
+ var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
+ ret.set_shape(new int[] {inputs.shape[0], -1});
+ return ret;
+ }
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
new file mode 100644
index 00000000..d533f128
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
@@ -0,0 +1,40 @@
+using System;
+using FluentAssertions;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.UnitTest.layers_test
+{
+ [TestClass]
+ public class flatten
+ {
+ [TestMethod]
+ public void Case1()
+ {
+ var sess = tf.Session().as_default();
+
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2));
+ sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
+ }
+
+ [TestMethod]
+ public void Case2()
+ {
+ var sess = tf.Session().as_default();
+
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
+ sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1);
+ }
+
+ [TestMethod]
+ public void Case3()
+ {
+ var sess = tf.Session().as_default();
+
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
+ new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw();
+ }
+ }
+}
\ No newline at end of file