Browse Source

added make_variable overload

tags/v0.9
Arnav Das 6 years ago
parent
commit
d3724a93c4
3 changed files with 35 additions and 20 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  2. +19
    -4
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  3. +15
    -15
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -190,7 +190,7 @@ namespace Tensorflow.Keras.Layers
var variable = _add_variable_with_custom_getter(name,
shape,
dtype: dtype,
getter: getter, // getter == null ? base_layer_utils.make_variable : getter,
getter: (getter == null) ? base_layer_utils.make_variable : getter,
overwrite: true,
initializer: initializer,
trainable: trainable.Value);


+ 19
- 4
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -8,6 +8,21 @@ namespace Tensorflow.Keras.Utils
{
public class base_layer_utils
{
/// <summary>
/// Adds a new variable to the layer.
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <param name="initializer"></param>
/// <param name="trainable"></param>
/// <returns></returns>
public static RefVariable make_variable(string name,
int[] shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true);

/// <summary>
/// Adds a new variable to the layer.
/// </summary>
@@ -28,7 +43,7 @@ namespace Tensorflow.Keras.Utils

ops.init_scope();

Func<Tensor> init_val = ()=> initializer.call(new TensorShape(shape), dtype: dtype);
Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);

var variable_dtype = dtype.as_base_dtype();
var v = tf.Variable(init_val);
@@ -44,13 +59,13 @@ namespace Tensorflow.Keras.Utils
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
string[] avoid_names = null, string @namespace = "", bool zero_based = false)
{
if(name_uid_map == null)
if (name_uid_map == null)
name_uid_map = get_default_graph_uid_map();
if (avoid_names == null)
avoid_names = new string[0];

string proposed_name = null;
while(proposed_name == null || avoid_names.Contains(proposed_name))
while (proposed_name == null || avoid_names.Contains(proposed_name))
{
var name_key = (@namespace, name);
if (!name_uid_map.ContainsKey(name_key))
@@ -58,7 +73,7 @@ namespace Tensorflow.Keras.Utils

if (zero_based)
{
int number = name_uid_map[name_key];
int number = name_uid_map[name_key];
if (number > 0)
proposed_name = $"{name}_{number}";
else


+ 15
- 15
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests
public void BasicOperations()
{
tf.Graph().as_default();
new BasicOperations() { Enabled = true }.Train();
new BasicOperations() { Enabled = true }.Run();
}
[TestMethod]
public void HelloWorld()
{
tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Train();
new HelloWorld() { Enabled = true }.Run();
}
[TestMethod]
public void ImageRecognition()
{
tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Train();
new HelloWorld() { Enabled = true }.Run();
}
[Ignore]
@@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests
public void InceptionArchGoogLeNet()
{
tf.Graph().as_default();
new InceptionArchGoogLeNet() { Enabled = true }.Train();
new InceptionArchGoogLeNet() { Enabled = true }.Run();
}
[TestMethod]
public void KMeansClustering()
{
tf.Graph().as_default();
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train();
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
}
[TestMethod]
public void LinearRegression()
{
tf.Graph().as_default();
new LinearRegression() { Enabled = true }.Train();
new LinearRegression() { Enabled = true }.Run();
}
[TestMethod]
public void LogisticRegression()
{
tf.Graph().as_default();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
}
[Ignore]
@@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests
public void NaiveBayesClassifier()
{
tf.Graph().as_default();
new NaiveBayesClassifier() { Enabled = false }.Train();
new NaiveBayesClassifier() { Enabled = false }.Run();
}
[Ignore]
@@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests
public void NamedEntityRecognition()
{
tf.Graph().as_default();
new NamedEntityRecognition() { Enabled = true }.Train();
new NamedEntityRecognition() { Enabled = true }.Run();
}
[TestMethod]
public void NearestNeighbor()
{
tf.Graph().as_default();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
}
[Ignore]
@@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationTrain()
{
tf.Graph().as_default();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
}
[Ignore]
@@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationWithMovieReviews()
{
tf.Graph().as_default();
new BinaryTextClassification() { Enabled = true }.Train();
new BinaryTextClassification() { Enabled = true }.Run();
}
[TestMethod]
public void NeuralNetXor()
{
tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Run());
}
[TestMethod]
public void NeuralNetXor_ImportedGraph()
{
tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Run());
}
@@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests
public void ObjectDetection()
{
tf.Graph().as_default();
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train());
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Run());
}
}
}

Loading…
Cancel
Save