diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
new file mode 100644
index 00000000..83cdb28a
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
@@ -0,0 +1,7 @@
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class ReshapeArgs : LayerArgs
+ {
+ public TensorShape TargetShape { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
index fa69d854..c0bfa321 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
@@ -34,5 +34,16 @@ namespace Tensorflow.Keras.Layers
{
Size = size ?? (2, 2)
});
+
+ ///
+ /// Layer that reshapes inputs into the given shape.
+ ///
+ ///
+ ///
+ public Reshape Reshape(TensorShape target_shape)
+ => new Reshape(new ReshapeArgs
+ {
+ TargetShape = target_shape
+ });
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index b638cc70..5a41c76e 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -372,8 +372,8 @@ namespace Tensorflow.Keras.Layers
InputShape = input_shape
});
- public Add Add(params Tensor[] inputs)
- => new Add(new MergeArgs { Inputs = inputs });
+ public Add Add()
+ => new Add(new MergeArgs { });
public GlobalAveragePooling2D GlobalAveragePooling2D()
=> new GlobalAveragePooling2D(new Pooling2DArgs { });
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
new file mode 100644
index 00000000..687bcafe
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
@@ -0,0 +1,34 @@
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using static Tensorflow.KerasApi;
+using static Tensorflow.Binding;
+using System.Collections.Generic;
+using System;
+
+namespace Tensorflow.Keras.Layers
+{
+ ///
+ /// Layer that reshapes inputs into the given shape.
+ ///
+ public class Reshape : Layer
+ {
+ ReshapeArgs args;
+ public Reshape(ReshapeArgs args)
+ : base(args)
+ {
+ this.args = args;
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
+ {
+ var shape = new List { inputs.shape[0] };
+ shape.AddRange(args.TargetShape.dims);
+
+ var result = array_ops.reshape(inputs, shape.ToArray());
+ if (!tf.Context.executing_eagerly())
+ // result = result.set_shape(compute_output_shape(inputs.shape));
+ throw new NotImplementedException("");
+ return result;
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs b/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs
index fc5bb525..6ce816d6 100644
--- a/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs
+++ b/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs
@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
-using Tensorflow;
+using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace TensorFlowNET.UnitTest.Keras
@@ -26,5 +26,14 @@ namespace TensorFlowNET.UnitTest.Keras
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x);
Assert.AreEqual((2, 2, 2, 3), y.shape);
}
+
+ [TestMethod]
+ public void Reshape()
+ {
+ var inputs = tf.zeros((10, 5, 20));
+ var outputs = keras.layers.LeakyReLU().Apply(inputs);
+ outputs = keras.layers.Reshape((20, 5)).Apply(outputs);
+ Assert.AreEqual((10, 20, 5), outputs.shape);
+ }
}
}