diff --git a/src/TensorFlowNET.Core/APIs/keras.layers.cs b/src/TensorFlowNET.Core/APIs/keras.layers.cs
index 016c883e..54a1032f 100644
--- a/src/TensorFlowNET.Core/APIs/keras.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/keras.layers.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow.Keras;
using Tensorflow.Keras.Engine;
@@ -12,10 +13,30 @@ namespace Tensorflow
public static class layers
{
public static Embedding Embedding(int input_dim, int output_dim,
- string embeddings_initializer = "uniform",
- bool mask_zero = false) => new Embedding(input_dim, output_dim,
- embeddings_initializer,
- mask_zero);
+ IInitializer embeddings_initializer = null,
+ bool mask_zero = false) => new Embedding(input_dim, output_dim,
+ embeddings_initializer,
+ mask_zero);
+
+ public static InputLayer Input(int[] batch_shape = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ string name = null,
+ bool sparse = false,
+ Tensor tensor = null)
+ {
+ var batch_size = batch_shape[0];
+ var shape = batch_shape.Skip(1).ToArray();
+
+ var input_layer = new InputLayer(
+ input_shape: shape,
+ batch_size: batch_size,
+ name: name,
+ dtype: dtype,
+ sparse: sparse,
+ input_tensor: tensor);
+
+ throw new NotImplementedException("");
+ }
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Network.cs b/src/TensorFlowNET.Core/Keras/Engine/Network.cs
index e50d7dd9..e06d4e05 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Network.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Network.cs
@@ -10,11 +10,16 @@ namespace Tensorflow.Keras.Engine
protected bool _is_compiled;
protected bool _expects_training_arg;
protected bool _compute_output_and_mask_jointly;
+ ///
+ /// All layers in order of horizontal graph traversal.
+ /// Entries are unique. Includes input and output layers.
+ ///
+ protected List _layers;
public Network(string name = null)
: base(name: name)
{
-
+ _init_subclassed_network(name);
}
protected virtual void _init_subclassed_network(string name = null)
@@ -30,6 +35,7 @@ namespace Tensorflow.Keras.Engine
_expects_training_arg = false;
_compute_output_and_mask_jointly = false;
supports_masking = false;
+ _layers = new List();
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
index a83c06e3..0801cdec 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
@@ -23,6 +23,18 @@ namespace Tensorflow.Keras.Engine
{
built = false;
var set_inputs = false;
+ if(_layers.Count == 0)
+ {
+ var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype);
+ if(batch_shape != null)
+ {
+ // Instantiate an input layer.
+ var x = keras.layers.Input(
+ batch_shape: batch_shape,
+ dtype: dtype,
+ name: layer._name + "_input");
+ }
+ }
}
public void __exit__()
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
index c7285def..494cad8b 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
@@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Layers
public Embedding(int input_dim, int output_dim,
IInitializer embeddings_initializer = null,
- bool mask_zero = false)
+ bool mask_zero = false,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape)
{
this.input_dim = input_dim;
this.output_dim = output_dim;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
new file mode 100644
index 00000000..e811776d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
@@ -0,0 +1,45 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Layers
+{
+ ///
+ /// Layer to be used as an entry point into a Network (a graph of layers).
+ ///
+ public class InputLayer : Layer
+ {
+ public bool sparse;
+ public int? batch_size;
+
+ public InputLayer(int[] input_shape = null,
+ int? batch_size = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ string name = null,
+ bool sparse = false,
+ Tensor input_tensor = null)
+ {
+ built = true;
+ this.sparse = sparse;
+ this.batch_size = batch_size;
+ this.supports_masking = true;
+
+ if(input_tensor == null)
+ {
+ var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 };
+
+ if (sparse)
+ {
+ throw new NotImplementedException("InputLayer sparse is true");
+ }
+ else
+ {
+ input_tensor = backend.placeholder(
+ shape: batch_input_shape,
+ dtype: dtype,
+ name: name);
+ }
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 0ed2a4ce..3595087f 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers
///
protected bool built;
protected bool trainable;
- protected TF_DataType _dtype;
+ public TF_DataType _dtype;
///
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
@@ -33,12 +33,16 @@ namespace Tensorflow.Keras.Layers
protected InputSpec input_spec;
protected bool supports_masking;
protected List _trainable_weights;
- protected string _name;
+ public string _name;
protected string _base_name;
protected bool _compute_previous_mask;
protected List _updates;
+ public int[] _batch_input_shape;
- public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ public Layer(bool trainable = true,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ int[] input_shape = null)
{
this.trainable = trainable;
this._dtype = dtype;
@@ -49,6 +53,12 @@ namespace Tensorflow.Keras.Layers
_trainable_weights = new List();
_compute_previous_mask = false;
_updates = new List();
+
+ // Manage input shape information if passed.
+
+ _batch_input_shape = new int[] { -1, -1 };
+
+ _dtype = dtype;
}
public Tensor __call__(Tensor inputs,
diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs
index 0196bfea..17ab0fbb 100644
--- a/src/TensorFlowNET.Core/Keras/backend.cs
+++ b/src/TensorFlowNET.Core/Keras/backend.cs
@@ -11,6 +11,22 @@ namespace Tensorflow.Keras
}
+ public static Tensor placeholder(int[] shape = null,
+ int ndim = -1,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ bool sparse = false,
+ string name = null)
+ {
+ if(sparse)
+ {
+ throw new NotImplementedException("placeholder sparse is true");
+ }
+ else
+ {
+ return gen_array_ops.placeholder(dtype: dtype, shape: new TensorShape(shape), name: name);
+ }
+ }
+
public static Graph get_graph()
{
return ops.get_default_graph();
diff --git a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs
similarity index 100%
rename from test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
rename to test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs