diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index b96f8203..12b93519 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 16.0.28803.156 +# Visual Studio 15 +VisualStudioVersion = 15.0.28307.168 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.UnitTest", "test\TensorFlowNET.UnitTest\TensorFlowNET.UnitTest.csproj", "{029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}" EndProject @@ -9,6 +9,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Core", "src\KerasNET.Core\Keras.Core.csproj", "{902E188F-A953-43B4-9991-72BAB1697BC3}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Keras.Example", "test\KerasNET.Example\Keras.Example.csproj", "{17E1AC16-9E0E-4545-905A-E92C6300C7AF}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\KerasNET.Test\Keras.UnitTest.csproj", "{A5839A45-A117-4BEA-898B-DE1ED6E0D58F}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -27,6 +33,18 @@ Global {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU + {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.Build.0 = Release|Any CPU + {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {17E1AC16-9E0E-4545-905A-E92C6300C7AF}.Release|Any CPU.Build.0 = Release|Any CPU + {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/KerasNET.Core/Core.cs b/src/KerasNET.Core/Core.cs new file mode 100644 index 00000000..8adae938 --- /dev/null +++ b/src/KerasNET.Core/Core.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace Keras +{ + public static class Keras + { + public static Tensor create_tensor(int[] shape, float mean = 0, float stddev = 1, TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null) + { + return tf.truncated_normal(shape: shape, mean: mean, stddev: stddev, dtype: dtype, seed: seed, name: name); + } + } +} diff --git a/src/KerasNET.Core/IInitializer.cs b/src/KerasNET.Core/IInitializer.cs new file mode 100644 index 00000000..53cb9112 --- /dev/null +++ b/src/KerasNET.Core/IInitializer.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Keras +{ + interface IInitializer + { + } +} diff --git a/src/KerasNET.Core/Initializer/BaseInitializer.cs b/src/KerasNET.Core/Initializer/BaseInitializer.cs new file mode 100644 index 00000000..84a420a7 --- /dev/null +++ b/src/KerasNET.Core/Initializer/BaseInitializer.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using Tensorflow.Layers; + +namespace Keras.Initializer +{ + class BaseInitializer : IInitializer + { + public int seed; + } +} diff --git a/src/KerasNET.Core/Keras.Core.csproj b/src/KerasNET.Core/Keras.Core.csproj new file mode 100644 index 00000000..3a7cfc6b --- /dev/null +++ b/src/KerasNET.Core/Keras.Core.csproj @@ -0,0 +1,13 @@ + + + + netstandard2.0 + Keras + Keras + + + + + + + diff --git a/src/KerasNET.Core/Layers/Dense.cs b/src/KerasNET.Core/Layers/Dense.cs new file mode 100644 index 00000000..66569882 --- /dev/null +++ b/src/KerasNET.Core/Layers/Dense.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using Tensorflow; +using static Keras.Keras; +using Keras; +using NumSharp; +using Tensorflow.Operations.Activation; + +namespace Keras.Layers +{ + public class Dense : ILayer + { + RefVariable W; + int units; + TensorShape WShape; + string name; + IActivation activation; + + public Dense(int units, string name = null, IActivation activation = null) + { + this.activation = activation; + this.units = units; + this.name = (string.IsNullOrEmpty(name) || string.IsNullOrWhiteSpace(name))?this.GetType().Name + "_" + this.GetType().GUID:name; + } + public ILayer __build__(TensorShape input_shape, int seed = 1, float stddev = -1f) + { + Console.WriteLine("Building Layer \"" + name + "\" ..."); + if (stddev == -1) + stddev = (float)(1 / Math.Sqrt(2)); + var dim = input_shape.Dimensions; + var input_dim = dim[dim.Length - 1]; + W = tf.Variable(create_tensor(new int[] { input_dim, units }, seed: seed, stddev: (float)stddev)); + WShape = new TensorShape(W.shape); + return this; + } + public Tensor __call__(Tensor x) + { + var dot = tf.matmul(x, W); + if (this.activation != null) + dot = activation.Activate(dot); + Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.getShape().Dimensions).ToString() + ")\" ..."); + return dot; + } + public TensorShape __shape__() + { + return WShape; + } + public TensorShape output_shape(TensorShape input_shape) + { + var output_shape = input_shape.Dimensions; + output_shape[output_shape.Length - 1] = units; + return new TensorShape(output_shape); + } + } +} diff --git a/src/KerasNET.Core/Layers/ILayer.cs b/src/KerasNET.Core/Layers/ILayer.cs new file mode 100644 index 00000000..2c033eef --- /dev/null +++ b/src/KerasNET.Core/Layers/ILayer.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using NumSharp; + +namespace Keras.Layers +{ + public interface ILayer + { + TensorShape __shape__(); + ILayer __build__(TensorShape input_shape, int seed = 1, float stddev = -1f); + Tensor __call__(Tensor x); + TensorShape output_shape(TensorShape input_shape); + } +} diff --git a/src/KerasNET.Core/Model.cs b/src/KerasNET.Core/Model.cs new file mode 100644 index 00000000..d1d889fc --- /dev/null +++ b/src/KerasNET.Core/Model.cs @@ -0,0 +1,127 @@ +using Keras.Layers; +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using static Keras.Keras; +using static Tensorflow.Python; + +namespace Keras +{ + public class Model + { + public Tensor Flow; + List layer_stack; + + public TensorShape InputShape; + + public Model() + { + layer_stack = new List(); + } + public Model Add(ILayer layer) + { + layer_stack.Add(layer); + return this; + } + public Model Add(IEnumerable layers) + { + layer_stack.AddRange(layers); + return this; + } + public Tensor getFlow() + { + try + { + return Flow; + } + catch (Exception ex) + { + return null; + } + } + public (Operation, Tensor, Tensor) make_graph(Tensor features, Tensor labels) + { + + // TODO : Creating Loss Functions And Optimizers..... + + #region Model Layers Graph + /* + var stddev = 1 / Math.Sqrt(2); + + var d1 = new Dense(num_hidden); + d1.__build__(features.getShape()); + var hidden_activations = tf.nn.relu(d1.__call__(features)); + + var d1_output = d1.output_shape(features.getShape()); + + + var d2 = new Dense(1); + d2.__build__(d1.output_shape(features.getShape()), seed: 17, stddev: (float)(1/ Math.Sqrt(num_hidden))); + var logits = d2.__call__(hidden_activations); + var predictions = tf.sigmoid(tf.squeeze(logits)); + */ + #endregion + + #region Model Graph Form Layer Stack + var flow_shape = features.getShape(); + Flow = features; + for (int i = 0; i < layer_stack.Count; i++) + { + layer_stack[i].__build__(flow_shape); + flow_shape = layer_stack[i].output_shape(flow_shape); + Flow = layer_stack[i].__call__(Flow); + } + var predictions = tf.sigmoid(tf.squeeze(Flow)); + + #endregion + + #region loss and optimizer + var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)), name: "loss"); + + var gs = tf.Variable(0, trainable: false, name: "global_step"); + var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs); + #endregion + + return (train_op, loss, gs); + } + public float train(int num_steps, (NDArray, NDArray) training_dataset) + { + var (X, Y) = training_dataset; + var x_shape = X.shape; + var batch_size = x_shape[0]; + var graph = tf.Graph().as_default(); + + var features = tf.placeholder(tf.float32, new TensorShape(batch_size, 2)); + var labels = tf.placeholder(tf.float32, new TensorShape(batch_size)); + + var (train_op, loss, gs) = this.make_graph(features, labels); + + var init = tf.global_variables_initializer(); + + float loss_value = 0; + with(tf.Session(graph), sess => + { + sess.run(init); + var step = 0; + + + while (step < num_steps) + { + var result = sess.run( + new ITensorOrOperation[] { train_op, gs, loss }, + new FeedItem(features, X), + new FeedItem(labels, Y)); + loss_value = result[2]; + step = result[1]; + if (step % 1000 == 0) + Console.WriteLine($"Step {step} loss: {loss_value}"); + } + Console.WriteLine($"Final loss: {loss_value}"); + }); + + return loss_value; + } + } +} diff --git a/test/KerasNET.Example/Keras.Example.csproj b/test/KerasNET.Example/Keras.Example.csproj new file mode 100644 index 00000000..44bf52a4 --- /dev/null +++ b/test/KerasNET.Example/Keras.Example.csproj @@ -0,0 +1,22 @@ + + + + Exe + netcoreapp2.2 + false + Keras.Example.Program + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/test/KerasNET.Example/Program.cs b/test/KerasNET.Example/Program.cs new file mode 100644 index 00000000..2fbf288c --- /dev/null +++ b/test/KerasNET.Example/Program.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow; +using static Tensorflow.Python; +using static Keras.Keras; +using Keras.Layers; +using Keras; +using NumSharp; + +namespace Keras.Example +{ + class Program + { + static void Main(string[] args) + { + Console.WriteLine("================================== Keras =================================="); + + #region data + var batch_size = 1000; + var (X, Y) = XOR(batch_size); + //var (X, Y, batch_size) = (np.array(new float[,]{{1, 0 },{1, 1 },{0, 0 },{0, 1 }}), np.array(new int[] { 0, 1, 1, 0 }), 4); + #endregion + + #region features + var (features, labels) = (new Tensor(X), new Tensor(Y)); + var num_steps = 10000; + #endregion + + #region model + var m = new Model(); + + //m.Add(new Dense(8, name: "Hidden", activation: tf.nn.relu())).Add(new Dense(1, name:"Output")); + + m.Add( + new ILayer[] { + new Dense(8, name: "Hidden_1", activation: tf.nn.relu()), + new Dense(1, name: "Output") + }); + + m.train(num_steps, (X, Y)); + #endregion + + Console.ReadKey(); + } + static (NDArray, NDArray) XOR(int samples) + { + var X = new List(); + var Y = new List(); + var r = new Random(); + for (int i = 0; i < samples; i++) + { + var x1 = (float)r.Next(0, 2); + var x2 = (float)r.Next(0, 2); + var y = 0.0f; + if (x1 == x2) + y = 1.0f; + X.Add(new float[] { x1, x2 }); + Y.Add(y); + } + + return (np.array(X.ToArray()), np.array(Y.ToArray())); + } + } +} diff --git a/test/KerasNET.Example/packages.config b/test/KerasNET.Example/packages.config new file mode 100644 index 00000000..e7c17277 --- /dev/null +++ b/test/KerasNET.Example/packages.config @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/test/KerasNET.Test/BaseTests.cs b/test/KerasNET.Test/BaseTests.cs new file mode 100644 index 00000000..6a716c5f --- /dev/null +++ b/test/KerasNET.Test/BaseTests.cs @@ -0,0 +1,28 @@ +using System; +using Tensorflow; +using Keras; +using Keras.Layers; +using NumSharp; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Keras.Test +{ + [TestClass] + public class BaseTests + { + [TestMethod] + public void Dense_Tensor_ShapeTest() + { + var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu()); + var input = new Tensor(np.array(new int[] { 3 })); + dense_1.__build__(input.getShape()); + var outputShape = dense_1.output_shape(input.getShape()); + var a = (int[])(outputShape.Dimensions); + var b = (int[])(new int[] { 1 }); + var _a = np.array(a); + var _b = np.array(b); + + Assert.IsTrue(np.array_equal(_a, _b)); + } + } +} diff --git a/test/KerasNET.Test/Keras.UnitTest.csproj b/test/KerasNET.Test/Keras.UnitTest.csproj new file mode 100644 index 00000000..89a5425c --- /dev/null +++ b/test/KerasNET.Test/Keras.UnitTest.csproj @@ -0,0 +1,40 @@ + + + + netcoreapp2.2 + + false + + Keras.UnitTest + + Keras.UnitTest + + + + Exe + + + + + + DEBUG;TRACE + true + + + + true + + + + + + + + + + + + + + + diff --git a/test/KerasNET.Test/packages.config b/test/KerasNET.Test/packages.config new file mode 100644 index 00000000..7e0fea67 --- /dev/null +++ b/test/KerasNET.Test/packages.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file