From d3724a93c40a8ca438dbf15b520baf708f9b3740 Mon Sep 17 00:00:00 2001 From: Arnav Das Date: Wed, 5 Jun 2019 20:13:53 +0530 Subject: [PATCH] added make_variable overload --- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 2 +- .../Keras/Utils/base_layer_utils.cs | 23 +++++++++++--- .../ExamplesTests/ExamplesTest.cs | 30 +++++++++---------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 7f7f75a2..db089959 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -190,7 +190,7 @@ namespace Tensorflow.Keras.Layers var variable = _add_variable_with_custom_getter(name, shape, dtype: dtype, - getter: getter, // getter == null ? base_layer_utils.make_variable : getter, + getter: (getter == null) ? base_layer_utils.make_variable : getter, overwrite: true, initializer: initializer, trainable: trainable.Value); diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index db927089..22c2cfc5 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -8,6 +8,21 @@ namespace Tensorflow.Keras.Utils { public class base_layer_utils { + /// + /// Adds a new variable to the layer. + /// + /// + /// + /// + /// + /// + /// + public static RefVariable make_variable(string name, + int[] shape, + TF_DataType dtype = TF_DataType.TF_FLOAT, + IInitializer initializer = null, + bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true); + /// /// Adds a new variable to the layer. /// @@ -28,7 +43,7 @@ namespace Tensorflow.Keras.Utils ops.init_scope(); - Func init_val = ()=> initializer.call(new TensorShape(shape), dtype: dtype); + Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); var variable_dtype = dtype.as_base_dtype(); var v = tf.Variable(init_val); @@ -44,13 +59,13 @@ namespace Tensorflow.Keras.Utils public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null, string[] avoid_names = null, string @namespace = "", bool zero_based = false) { - if(name_uid_map == null) + if (name_uid_map == null) name_uid_map = get_default_graph_uid_map(); if (avoid_names == null) avoid_names = new string[0]; string proposed_name = null; - while(proposed_name == null || avoid_names.Contains(proposed_name)) + while (proposed_name == null || avoid_names.Contains(proposed_name)) { var name_key = (@namespace, name); if (!name_uid_map.ContainsKey(name_key)) @@ -58,7 +73,7 @@ namespace Tensorflow.Keras.Utils if (zero_based) { - int number = name_uid_map[name_key]; + int number = name_uid_map[name_key]; if (number > 0) proposed_name = $"{name}_{number}"; else diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index c5497692..214bb835 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests public void BasicOperations() { tf.Graph().as_default(); - new BasicOperations() { Enabled = true }.Train(); + new BasicOperations() { Enabled = true }.Run(); } [TestMethod] public void HelloWorld() { tf.Graph().as_default(); - new HelloWorld() { Enabled = true }.Train(); + new HelloWorld() { Enabled = true }.Run(); } [TestMethod] public void ImageRecognition() { tf.Graph().as_default(); - new HelloWorld() { Enabled = true }.Train(); + new HelloWorld() { Enabled = true }.Run(); } [Ignore] @@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests public void InceptionArchGoogLeNet() { tf.Graph().as_default(); - new InceptionArchGoogLeNet() { Enabled = true }.Train(); + new InceptionArchGoogLeNet() { Enabled = true }.Run(); } [TestMethod] public void KMeansClustering() { tf.Graph().as_default(); - new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train(); + new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); } [TestMethod] public void LinearRegression() { tf.Graph().as_default(); - new LinearRegression() { Enabled = true }.Train(); + new LinearRegression() { Enabled = true }.Run(); } [TestMethod] public void LogisticRegression() { tf.Graph().as_default(); - new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train(); + new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run(); } [Ignore] @@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests public void NaiveBayesClassifier() { tf.Graph().as_default(); - new NaiveBayesClassifier() { Enabled = false }.Train(); + new NaiveBayesClassifier() { Enabled = false }.Run(); } [Ignore] @@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests public void NamedEntityRecognition() { tf.Graph().as_default(); - new NamedEntityRecognition() { Enabled = true }.Train(); + new NamedEntityRecognition() { Enabled = true }.Run(); } [TestMethod] public void NearestNeighbor() { tf.Graph().as_default(); - new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train(); + new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); } [Ignore] @@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests public void TextClassificationTrain() { tf.Graph().as_default(); - new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train(); + new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); } [Ignore] @@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests public void TextClassificationWithMovieReviews() { tf.Graph().as_default(); - new BinaryTextClassification() { Enabled = true }.Train(); + new BinaryTextClassification() { Enabled = true }.Run(); } [TestMethod] public void NeuralNetXor() { tf.Graph().as_default(); - Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train()); + Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Run()); } [TestMethod] public void NeuralNetXor_ImportedGraph() { tf.Graph().as_default(); - Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train()); + Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Run()); } @@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests public void ObjectDetection() { tf.Graph().as_default(); - Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train()); + Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Run()); } } }