diff --git a/src/TensorFlowNET.Core/APIs/keras.layers.cs b/src/TensorFlowNET.Core/APIs/keras.layers.cs index 016c883e..54a1032f 100644 --- a/src/TensorFlowNET.Core/APIs/keras.layers.cs +++ b/src/TensorFlowNET.Core/APIs/keras.layers.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Keras; using Tensorflow.Keras.Engine; @@ -12,10 +13,30 @@ namespace Tensorflow public static class layers { public static Embedding Embedding(int input_dim, int output_dim, - string embeddings_initializer = "uniform", - bool mask_zero = false) => new Embedding(input_dim, output_dim, - embeddings_initializer, - mask_zero); + IInitializer embeddings_initializer = null, + bool mask_zero = false) => new Embedding(input_dim, output_dim, + embeddings_initializer, + mask_zero); + + public static InputLayer Input(int[] batch_shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, + bool sparse = false, + Tensor tensor = null) + { + var batch_size = batch_shape[0]; + var shape = batch_shape.Skip(1).ToArray(); + + var input_layer = new InputLayer( + input_shape: shape, + batch_size: batch_size, + name: name, + dtype: dtype, + sparse: sparse, + input_tensor: tensor); + + throw new NotImplementedException(""); + } } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Network.cs b/src/TensorFlowNET.Core/Keras/Engine/Network.cs index e50d7dd9..e06d4e05 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Network.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Network.cs @@ -10,11 +10,16 @@ namespace Tensorflow.Keras.Engine protected bool _is_compiled; protected bool _expects_training_arg; protected bool _compute_output_and_mask_jointly; + /// + /// All layers in order of horizontal graph traversal. + /// Entries are unique. Includes input and output layers. + /// + protected List _layers; public Network(string name = null) : base(name: name) { - + _init_subclassed_network(name); } protected virtual void _init_subclassed_network(string name = null) @@ -30,6 +35,7 @@ namespace Tensorflow.Keras.Engine _expects_training_arg = false; _compute_output_and_mask_jointly = false; supports_masking = false; + _layers = new List(); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index a83c06e3..0801cdec 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -23,6 +23,18 @@ namespace Tensorflow.Keras.Engine { built = false; var set_inputs = false; + if(_layers.Count == 0) + { + var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype); + if(batch_shape != null) + { + // Instantiate an input layer. + var x = keras.layers.Input( + batch_shape: batch_shape, + dtype: dtype, + name: layer._name + "_input"); + } + } } public void __exit__() diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index c7285def..494cad8b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Layers public Embedding(int input_dim, int output_dim, IInitializer embeddings_initializer = null, - bool mask_zero = false) + bool mask_zero = false, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape) { this.input_dim = input_dim; this.output_dim = output_dim; diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs new file mode 100644 index 00000000..e811776d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Layer to be used as an entry point into a Network (a graph of layers). + /// + public class InputLayer : Layer + { + public bool sparse; + public int? batch_size; + + public InputLayer(int[] input_shape = null, + int? batch_size = null, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, + bool sparse = false, + Tensor input_tensor = null) + { + built = true; + this.sparse = sparse; + this.batch_size = batch_size; + this.supports_masking = true; + + if(input_tensor == null) + { + var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 }; + + if (sparse) + { + throw new NotImplementedException("InputLayer sparse is true"); + } + else + { + input_tensor = backend.placeholder( + shape: batch_input_shape, + dtype: dtype, + name: name); + } + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 0ed2a4ce..3595087f 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers /// protected bool built; protected bool trainable; - protected TF_DataType _dtype; + public TF_DataType _dtype; /// /// A stateful layer is a layer whose updates are run during inference too, /// for instance stateful RNNs. @@ -33,12 +33,16 @@ namespace Tensorflow.Keras.Layers protected InputSpec input_spec; protected bool supports_masking; protected List _trainable_weights; - protected string _name; + public string _name; protected string _base_name; protected bool _compute_previous_mask; protected List _updates; + public int[] _batch_input_shape; - public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) + public Layer(bool trainable = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int[] input_shape = null) { this.trainable = trainable; this._dtype = dtype; @@ -49,6 +53,12 @@ namespace Tensorflow.Keras.Layers _trainable_weights = new List(); _compute_previous_mask = false; _updates = new List(); + + // Manage input shape information if passed. + + _batch_input_shape = new int[] { -1, -1 }; + + _dtype = dtype; } public Tensor __call__(Tensor inputs, diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs index 0196bfea..17ab0fbb 100644 --- a/src/TensorFlowNET.Core/Keras/backend.cs +++ b/src/TensorFlowNET.Core/Keras/backend.cs @@ -11,6 +11,22 @@ namespace Tensorflow.Keras } + public static Tensor placeholder(int[] shape = null, + int ndim = -1, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, + string name = null) + { + if(sparse) + { + throw new NotImplementedException("placeholder sparse is true"); + } + else + { + return gen_array_ops.placeholder(dtype: dtype, shape: new TensorShape(shape), name: name); + } + } + public static Graph get_graph() { return ops.get_default_graph(); diff --git a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs similarity index 100% rename from test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs rename to test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs