diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 7f7f75a2..db089959 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -190,7 +190,7 @@ namespace Tensorflow.Keras.Layers
var variable = _add_variable_with_custom_getter(name,
shape,
dtype: dtype,
- getter: getter, // getter == null ? base_layer_utils.make_variable : getter,
+ getter: (getter == null) ? base_layer_utils.make_variable : getter,
overwrite: true,
initializer: initializer,
trainable: trainable.Value);
diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
index db927089..22c2cfc5 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
@@ -8,6 +8,21 @@ namespace Tensorflow.Keras.Utils
{
public class base_layer_utils
{
+ ///
+ /// Adds a new variable to the layer.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static RefVariable make_variable(string name,
+ int[] shape,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ IInitializer initializer = null,
+ bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true);
+
///
/// Adds a new variable to the layer.
///
@@ -28,7 +43,7 @@ namespace Tensorflow.Keras.Utils
ops.init_scope();
- Func init_val = ()=> initializer.call(new TensorShape(shape), dtype: dtype);
+ Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
var variable_dtype = dtype.as_base_dtype();
var v = tf.Variable(init_val);
@@ -44,13 +59,13 @@ namespace Tensorflow.Keras.Utils
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
string[] avoid_names = null, string @namespace = "", bool zero_based = false)
{
- if(name_uid_map == null)
+ if (name_uid_map == null)
name_uid_map = get_default_graph_uid_map();
if (avoid_names == null)
avoid_names = new string[0];
string proposed_name = null;
- while(proposed_name == null || avoid_names.Contains(proposed_name))
+ while (proposed_name == null || avoid_names.Contains(proposed_name))
{
var name_key = (@namespace, name);
if (!name_uid_map.ContainsKey(name_key))
@@ -58,7 +73,7 @@ namespace Tensorflow.Keras.Utils
if (zero_based)
{
- int number = name_uid_map[name_key];
+ int number = name_uid_map[name_key];
if (number > 0)
proposed_name = $"{name}_{number}";
else
diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
index c5497692..214bb835 100644
--- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
+++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
@@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests
public void BasicOperations()
{
tf.Graph().as_default();
- new BasicOperations() { Enabled = true }.Train();
+ new BasicOperations() { Enabled = true }.Run();
}
[TestMethod]
public void HelloWorld()
{
tf.Graph().as_default();
- new HelloWorld() { Enabled = true }.Train();
+ new HelloWorld() { Enabled = true }.Run();
}
[TestMethod]
public void ImageRecognition()
{
tf.Graph().as_default();
- new HelloWorld() { Enabled = true }.Train();
+ new HelloWorld() { Enabled = true }.Run();
}
[Ignore]
@@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests
public void InceptionArchGoogLeNet()
{
tf.Graph().as_default();
- new InceptionArchGoogLeNet() { Enabled = true }.Train();
+ new InceptionArchGoogLeNet() { Enabled = true }.Run();
}
[TestMethod]
public void KMeansClustering()
{
tf.Graph().as_default();
- new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train();
+ new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
}
[TestMethod]
public void LinearRegression()
{
tf.Graph().as_default();
- new LinearRegression() { Enabled = true }.Train();
+ new LinearRegression() { Enabled = true }.Run();
}
[TestMethod]
public void LogisticRegression()
{
tf.Graph().as_default();
- new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train();
+ new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
}
[Ignore]
@@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests
public void NaiveBayesClassifier()
{
tf.Graph().as_default();
- new NaiveBayesClassifier() { Enabled = false }.Train();
+ new NaiveBayesClassifier() { Enabled = false }.Run();
}
[Ignore]
@@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests
public void NamedEntityRecognition()
{
tf.Graph().as_default();
- new NamedEntityRecognition() { Enabled = true }.Train();
+ new NamedEntityRecognition() { Enabled = true }.Run();
}
[TestMethod]
public void NearestNeighbor()
{
tf.Graph().as_default();
- new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train();
+ new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
}
[Ignore]
@@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationTrain()
{
tf.Graph().as_default();
- new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train();
+ new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
}
[Ignore]
@@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationWithMovieReviews()
{
tf.Graph().as_default();
- new BinaryTextClassification() { Enabled = true }.Train();
+ new BinaryTextClassification() { Enabled = true }.Run();
}
[TestMethod]
public void NeuralNetXor()
{
tf.Graph().as_default();
- Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train());
+ Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Run());
}
[TestMethod]
public void NeuralNetXor_ImportedGraph()
{
tf.Graph().as_default();
- Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train());
+ Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Run());
}
@@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests
public void ObjectDetection()
{
tf.Graph().as_default();
- Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train());
+ Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Run());
}
}
}