diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index b96f8203..12b93519 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -1,7 +1,7 @@
Microsoft Visual Studio Solution File, Format Version 12.00
-# Visual Studio Version 16
-VisualStudioVersion = 16.0.28803.156
+# Visual Studio 15
+VisualStudioVersion = 15.0.28307.168
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.UnitTest", "test\TensorFlowNET.UnitTest\TensorFlowNET.UnitTest.csproj", "{029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}"
EndProject
@@ -9,6 +9,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Core", "src\KerasNET.Core\Keras.Core.csproj", "{902E188F-A953-43B4-9991-72BAB1697BC3}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Keras.Example", "test\KerasNET.Example\Keras.Example.csproj", "{17E1AC16-9E0E-4545-905A-E92C6300C7AF}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\KerasNET.Test\Keras.UnitTest.csproj", "{A5839A45-A117-4BEA-898B-DE1ED6E0D58F}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +33,18 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
+ {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.Build.0 = Release|Any CPU
+ {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/KerasNET.Core/Core.cs b/src/KerasNET.Core/Core.cs
new file mode 100644
index 00000000..8adae938
--- /dev/null
+++ b/src/KerasNET.Core/Core.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace Keras
+{
+ public static class Keras
+ {
+ public static Tensor create_tensor(int[] shape, float mean = 0, float stddev = 1, TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null)
+ {
+ return tf.truncated_normal(shape: shape, mean: mean, stddev: stddev, dtype: dtype, seed: seed, name: name);
+ }
+ }
+}
diff --git a/src/KerasNET.Core/IInitializer.cs b/src/KerasNET.Core/IInitializer.cs
new file mode 100644
index 00000000..53cb9112
--- /dev/null
+++ b/src/KerasNET.Core/IInitializer.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Keras
+{
+ interface IInitializer
+ {
+ }
+}
diff --git a/src/KerasNET.Core/Initializer/BaseInitializer.cs b/src/KerasNET.Core/Initializer/BaseInitializer.cs
new file mode 100644
index 00000000..84a420a7
--- /dev/null
+++ b/src/KerasNET.Core/Initializer/BaseInitializer.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using Tensorflow.Layers;
+
+namespace Keras.Initializer
+{
+ class BaseInitializer : IInitializer
+ {
+ public int seed;
+ }
+}
diff --git a/src/KerasNET.Core/Keras.Core.csproj b/src/KerasNET.Core/Keras.Core.csproj
new file mode 100644
index 00000000..3a7cfc6b
--- /dev/null
+++ b/src/KerasNET.Core/Keras.Core.csproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ Keras
+ Keras
+
+
+
+
+
+
+
diff --git a/src/KerasNET.Core/Layers/Dense.cs b/src/KerasNET.Core/Layers/Dense.cs
new file mode 100644
index 00000000..66569882
--- /dev/null
+++ b/src/KerasNET.Core/Layers/Dense.cs
@@ -0,0 +1,57 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Linq;
+using Tensorflow;
+using static Keras.Keras;
+using Keras;
+using NumSharp;
+using Tensorflow.Operations.Activation;
+
+namespace Keras.Layers
+{
+ public class Dense : ILayer
+ {
+ RefVariable W;
+ int units;
+ TensorShape WShape;
+ string name;
+ IActivation activation;
+
+ public Dense(int units, string name = null, IActivation activation = null)
+ {
+ this.activation = activation;
+ this.units = units;
+ this.name = (string.IsNullOrEmpty(name) || string.IsNullOrWhiteSpace(name))?this.GetType().Name + "_" + this.GetType().GUID:name;
+ }
+ public ILayer __build__(TensorShape input_shape, int seed = 1, float stddev = -1f)
+ {
+ Console.WriteLine("Building Layer \"" + name + "\" ...");
+ if (stddev == -1)
+ stddev = (float)(1 / Math.Sqrt(2));
+ var dim = input_shape.Dimensions;
+ var input_dim = dim[dim.Length - 1];
+ W = tf.Variable(create_tensor(new int[] { input_dim, units }, seed: seed, stddev: (float)stddev));
+ WShape = new TensorShape(W.shape);
+ return this;
+ }
+ public Tensor __call__(Tensor x)
+ {
+ var dot = tf.matmul(x, W);
+ if (this.activation != null)
+ dot = activation.Activate(dot);
+ Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.getShape().Dimensions).ToString() + ")\" ...");
+ return dot;
+ }
+ public TensorShape __shape__()
+ {
+ return WShape;
+ }
+ public TensorShape output_shape(TensorShape input_shape)
+ {
+ var output_shape = input_shape.Dimensions;
+ output_shape[output_shape.Length - 1] = units;
+ return new TensorShape(output_shape);
+ }
+ }
+}
diff --git a/src/KerasNET.Core/Layers/ILayer.cs b/src/KerasNET.Core/Layers/ILayer.cs
new file mode 100644
index 00000000..2c033eef
--- /dev/null
+++ b/src/KerasNET.Core/Layers/ILayer.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using NumSharp;
+
+namespace Keras.Layers
+{
+ public interface ILayer
+ {
+ TensorShape __shape__();
+ ILayer __build__(TensorShape input_shape, int seed = 1, float stddev = -1f);
+ Tensor __call__(Tensor x);
+ TensorShape output_shape(TensorShape input_shape);
+ }
+}
diff --git a/src/KerasNET.Core/Model.cs b/src/KerasNET.Core/Model.cs
new file mode 100644
index 00000000..d1d889fc
--- /dev/null
+++ b/src/KerasNET.Core/Model.cs
@@ -0,0 +1,127 @@
+using Keras.Layers;
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using static Keras.Keras;
+using static Tensorflow.Python;
+
+namespace Keras
+{
+ public class Model
+ {
+ public Tensor Flow;
+ List layer_stack;
+
+ public TensorShape InputShape;
+
+ public Model()
+ {
+ layer_stack = new List();
+ }
+ public Model Add(ILayer layer)
+ {
+ layer_stack.Add(layer);
+ return this;
+ }
+ public Model Add(IEnumerable layers)
+ {
+ layer_stack.AddRange(layers);
+ return this;
+ }
+ public Tensor getFlow()
+ {
+ try
+ {
+ return Flow;
+ }
+ catch (Exception ex)
+ {
+ return null;
+ }
+ }
+ public (Operation, Tensor, Tensor) make_graph(Tensor features, Tensor labels)
+ {
+
+ // TODO : Creating Loss Functions And Optimizers.....
+
+ #region Model Layers Graph
+ /*
+ var stddev = 1 / Math.Sqrt(2);
+
+ var d1 = new Dense(num_hidden);
+ d1.__build__(features.getShape());
+ var hidden_activations = tf.nn.relu(d1.__call__(features));
+
+ var d1_output = d1.output_shape(features.getShape());
+
+
+ var d2 = new Dense(1);
+ d2.__build__(d1.output_shape(features.getShape()), seed: 17, stddev: (float)(1/ Math.Sqrt(num_hidden)));
+ var logits = d2.__call__(hidden_activations);
+ var predictions = tf.sigmoid(tf.squeeze(logits));
+ */
+ #endregion
+
+ #region Model Graph Form Layer Stack
+ var flow_shape = features.getShape();
+ Flow = features;
+ for (int i = 0; i < layer_stack.Count; i++)
+ {
+ layer_stack[i].__build__(flow_shape);
+ flow_shape = layer_stack[i].output_shape(flow_shape);
+ Flow = layer_stack[i].__call__(Flow);
+ }
+ var predictions = tf.sigmoid(tf.squeeze(Flow));
+
+ #endregion
+
+ #region loss and optimizer
+ var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)), name: "loss");
+
+ var gs = tf.Variable(0, trainable: false, name: "global_step");
+ var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
+ #endregion
+
+ return (train_op, loss, gs);
+ }
+ public float train(int num_steps, (NDArray, NDArray) training_dataset)
+ {
+ var (X, Y) = training_dataset;
+ var x_shape = X.shape;
+ var batch_size = x_shape[0];
+ var graph = tf.Graph().as_default();
+
+ var features = tf.placeholder(tf.float32, new TensorShape(batch_size, 2));
+ var labels = tf.placeholder(tf.float32, new TensorShape(batch_size));
+
+ var (train_op, loss, gs) = this.make_graph(features, labels);
+
+ var init = tf.global_variables_initializer();
+
+ float loss_value = 0;
+ with(tf.Session(graph), sess =>
+ {
+ sess.run(init);
+ var step = 0;
+
+
+ while (step < num_steps)
+ {
+ var result = sess.run(
+ new ITensorOrOperation[] { train_op, gs, loss },
+ new FeedItem(features, X),
+ new FeedItem(labels, Y));
+ loss_value = result[2];
+ step = result[1];
+ if (step % 1000 == 0)
+ Console.WriteLine($"Step {step} loss: {loss_value}");
+ }
+ Console.WriteLine($"Final loss: {loss_value}");
+ });
+
+ return loss_value;
+ }
+ }
+}
diff --git a/test/KerasNET.Example/Keras.Example.csproj b/test/KerasNET.Example/Keras.Example.csproj
new file mode 100644
index 00000000..44bf52a4
--- /dev/null
+++ b/test/KerasNET.Example/Keras.Example.csproj
@@ -0,0 +1,22 @@
+
+
+
+ Exe
+ netcoreapp2.2
+ false
+ Keras.Example.Program
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/test/KerasNET.Example/Program.cs b/test/KerasNET.Example/Program.cs
new file mode 100644
index 00000000..2fbf288c
--- /dev/null
+++ b/test/KerasNET.Example/Program.cs
@@ -0,0 +1,67 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow;
+using static Tensorflow.Python;
+using static Keras.Keras;
+using Keras.Layers;
+using Keras;
+using NumSharp;
+
+namespace Keras.Example
+{
+ class Program
+ {
+ static void Main(string[] args)
+ {
+ Console.WriteLine("================================== Keras ==================================");
+
+ #region data
+ var batch_size = 1000;
+ var (X, Y) = XOR(batch_size);
+ //var (X, Y, batch_size) = (np.array(new float[,]{{1, 0 },{1, 1 },{0, 0 },{0, 1 }}), np.array(new int[] { 0, 1, 1, 0 }), 4);
+ #endregion
+
+ #region features
+ var (features, labels) = (new Tensor(X), new Tensor(Y));
+ var num_steps = 10000;
+ #endregion
+
+ #region model
+ var m = new Model();
+
+ //m.Add(new Dense(8, name: "Hidden", activation: tf.nn.relu())).Add(new Dense(1, name:"Output"));
+
+ m.Add(
+ new ILayer[] {
+ new Dense(8, name: "Hidden_1", activation: tf.nn.relu()),
+ new Dense(1, name: "Output")
+ });
+
+ m.train(num_steps, (X, Y));
+ #endregion
+
+ Console.ReadKey();
+ }
+ static (NDArray, NDArray) XOR(int samples)
+ {
+ var X = new List();
+ var Y = new List();
+ var r = new Random();
+ for (int i = 0; i < samples; i++)
+ {
+ var x1 = (float)r.Next(0, 2);
+ var x2 = (float)r.Next(0, 2);
+ var y = 0.0f;
+ if (x1 == x2)
+ y = 1.0f;
+ X.Add(new float[] { x1, x2 });
+ Y.Add(y);
+ }
+
+ return (np.array(X.ToArray()), np.array(Y.ToArray()));
+ }
+ }
+}
diff --git a/test/KerasNET.Example/packages.config b/test/KerasNET.Example/packages.config
new file mode 100644
index 00000000..e7c17277
--- /dev/null
+++ b/test/KerasNET.Example/packages.config
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/test/KerasNET.Test/BaseTests.cs b/test/KerasNET.Test/BaseTests.cs
new file mode 100644
index 00000000..6a716c5f
--- /dev/null
+++ b/test/KerasNET.Test/BaseTests.cs
@@ -0,0 +1,28 @@
+using System;
+using Tensorflow;
+using Keras;
+using Keras.Layers;
+using NumSharp;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Keras.Test
+{
+ [TestClass]
+ public class BaseTests
+ {
+ [TestMethod]
+ public void Dense_Tensor_ShapeTest()
+ {
+ var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu());
+ var input = new Tensor(np.array(new int[] { 3 }));
+ dense_1.__build__(input.getShape());
+ var outputShape = dense_1.output_shape(input.getShape());
+ var a = (int[])(outputShape.Dimensions);
+ var b = (int[])(new int[] { 1 });
+ var _a = np.array(a);
+ var _b = np.array(b);
+
+ Assert.IsTrue(np.array_equal(_a, _b));
+ }
+ }
+}
diff --git a/test/KerasNET.Test/Keras.UnitTest.csproj b/test/KerasNET.Test/Keras.UnitTest.csproj
new file mode 100644
index 00000000..89a5425c
--- /dev/null
+++ b/test/KerasNET.Test/Keras.UnitTest.csproj
@@ -0,0 +1,40 @@
+
+
+
+ netcoreapp2.2
+
+ false
+
+ Keras.UnitTest
+
+ Keras.UnitTest
+
+
+
+ Exe
+
+
+
+
+
+ DEBUG;TRACE
+ true
+
+
+
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/KerasNET.Test/packages.config b/test/KerasNET.Test/packages.config
new file mode 100644
index 00000000..7e0fea67
--- /dev/null
+++ b/test/KerasNET.Test/packages.config
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file