diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs new file mode 100644 index 00000000..83cdb28a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ReshapeArgs : LayerArgs + { + public TensorShape TargetShape { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs index fa69d854..c0bfa321 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs @@ -34,5 +34,16 @@ namespace Tensorflow.Keras.Layers { Size = size ?? (2, 2) }); + + /// + /// Layer that reshapes inputs into the given shape. + /// + /// + /// + public Reshape Reshape(TensorShape target_shape) + => new Reshape(new ReshapeArgs + { + TargetShape = target_shape + }); } } diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index b638cc70..5a41c76e 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -372,8 +372,8 @@ namespace Tensorflow.Keras.Layers InputShape = input_shape }); - public Add Add(params Tensor[] inputs) - => new Add(new MergeArgs { Inputs = inputs }); + public Add Add() + => new Add(new MergeArgs { }); public GlobalAveragePooling2D GlobalAveragePooling2D() => new GlobalAveragePooling2D(new Pooling2DArgs { }); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs new file mode 100644 index 00000000..687bcafe --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -0,0 +1,34 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.KerasApi; +using static Tensorflow.Binding; +using System.Collections.Generic; +using System; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Layer that reshapes inputs into the given shape. + /// + public class Reshape : Layer + { + ReshapeArgs args; + public Reshape(ReshapeArgs args) + : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + { + var shape = new List { inputs.shape[0] }; + shape.AddRange(args.TargetShape.dims); + + var result = array_ops.reshape(inputs, shape.ToArray()); + if (!tf.Context.executing_eagerly()) + // result = result.set_shape(compute_output_shape(inputs.shape)); + throw new NotImplementedException(""); + return result; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs b/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs index fc5bb525..6ce816d6 100644 --- a/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs +++ b/test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs @@ -1,6 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; -using Tensorflow; +using static Tensorflow.Binding; using static Tensorflow.KerasApi; namespace TensorFlowNET.UnitTest.Keras @@ -26,5 +26,14 @@ namespace TensorFlowNET.UnitTest.Keras var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x); Assert.AreEqual((2, 2, 2, 3), y.shape); } + + [TestMethod] + public void Reshape() + { + var inputs = tf.zeros((10, 5, 20)); + var outputs = keras.layers.LeakyReLU().Apply(inputs); + outputs = keras.layers.Reshape((20, 5)).Apply(outputs); + Assert.AreEqual((10, 20, 5), outputs.shape); + } } }