From 91608875a0612ec4d35547ea304d8d3eda8ac0a6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Mar 2019 18:37:08 -0600 Subject: [PATCH] VarCreation1 passed. create variable through variable_scope. --- README.md | 2 +- src/TensorFlowNET.Core/APIs/tf.init.cs | 2 +- .../Operations/gen_random_ops.py.cs | 26 +++++++++++++- .../Operations/random_ops.py.cs | 6 +++- src/TensorFlowNET.Core/Tensors/Tensor.cs | 5 +++ .../Variables/RefVariable.cs | 31 ++++++++++++---- .../Variables/_VariableStore.cs | 35 ++++++++++--------- src/TensorFlowNET.Core/Variables/state_ops.cs | 28 +++++++++++++++ .../LinearRegression.cs | 2 +- test/TensorFlowNET.UnitTest/VariableTest.cs | 6 +++- 10 files changed, 114 insertions(+), 29 deletions(-) create mode 100644 src/TensorFlowNET.Core/Variables/state_ops.cs diff --git a/README.md b/README.md index 42cd3ec4..fd92c400 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ using(var sess = tf.Session()) Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html). -More examples: +### More examples: * [Hello World](test/TensorFlowNET.Examples/HelloWorld.cs) * [Basic Operations](test/TensorFlowNET.Examples/BasicOperations.cs) diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 5b80313c..32b75807 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -7,7 +7,7 @@ namespace Tensorflow public static partial class tf { public static IInitializer zeros_initializer => new Zeros(); - public static IInitializer glorot_uniform => new GlorotUniform(); + public static IInitializer glorot_uniform_initializer => new GlorotUniform(); public static variable_scope variable_scope(string name_or_scope, string default_name = null, diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs index 8cf40fd1..354177f9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs @@ -24,10 +24,34 @@ namespace Tensorflow if (!seed2.HasValue) seed2 = 0; - var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name, + var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", + name: name, args: new { shape, dtype, seed, seed2 }); return _op.outputs[0]; } + + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) + { + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + + var _op = _op_def_lib._apply_op_helper("RandomUniform", + name: name, + args: new { shape, dtype, seed, seed2}); + + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index eae27c58..00f1aa00 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -56,7 +56,11 @@ namespace Tensorflow return with(new ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => { name = scope; - return null; + var tensorShape = _ShapeTensor(shape); + var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); + var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); + var rnd = gen_random_ops.random_uniform(tensorShape, dtype); + return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name); }); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index e2dd55a5..76db062e 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -77,6 +77,11 @@ namespace Tensorflow return null; } + public TensorShape getShape() + { + return tensor_util.to_shape(shape); + } + /// /// number of dimensions /// 0 Scalar (magnitude only) diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index f0e4e721..ea7069e4 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -1,4 +1,6 @@ -using System; +using Google.Protobuf; +using Google.Protobuf.Collections; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -99,7 +101,7 @@ namespace Tensorflow if (initial_value is null) throw new ValueError("initial_value must be specified."); - var init_from_fn = false; + var init_from_fn = initial_value.GetType().Name == "Func`1"; if(collections == null) { @@ -115,12 +117,27 @@ namespace Tensorflow collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); ops.init_scope(); - var values = init_from_fn ? new List() : new List { initial_value }; - Python.with(new ops.name_scope(name, "Variable", values), scope => + var values = init_from_fn ? new object[0] : new object[] { initial_value }; + with(new ops.name_scope(name, "Variable", values), scope => { + name = scope; if (init_from_fn) { - + // Use attr_scope and device(None) to simulate the behavior of + // colocate_with when the variable we want to colocate with doesn't + // yet exist. + string true_name = ops._name_from_scope_name(name); + var attr = new AttrValue + { + List = new AttrValue.Types.ListValue() + }; + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); + with(new ops.name_scope("Initializer"), scope2 => + { + _initial_value = (initial_value as Func)(); + _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); + _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + }); } // Or get the initial value from a Tensor or Python object. else @@ -135,7 +152,9 @@ namespace Tensorflow // Manually overrides the variable's shape with the initial value's. if (validate_shape) { - var initial_value_shape = _initial_value.shape; + var initial_value_shape = _initial_value.getShape(); + if (!initial_value_shape.is_fully_defined()) + throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); } // If 'initial_value' makes use of other variables, make sure we don't diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 873d33a6..5bd8c86d 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -74,7 +74,9 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { - bool initializing_from_value = false; + bool initializing_from_value = true; + if (use_resource == null) + use_resource = false; if (_vars.ContainsKey(name)) { @@ -86,15 +88,18 @@ namespace Tensorflow throw new NotImplementedException("_get_single_variable"); } - Tensor init_val = null; - + RefVariable v = null; // Create the tensor to initialize the variable with default value. if (initializer == null) { if (dtype.is_floating()) - initializer = tf.glorot_uniform; + { + initializer = tf.glorot_uniform_initializer; + initializing_from_value = false; + } } + // Create the variable. ops.init_scope(); { if (initializing_from_value) @@ -103,23 +108,19 @@ namespace Tensorflow } else { - init_val = initializer.call(shape, dtype); + Func init_val = () => initializer.call(shape, dtype); var variable_dtype = dtype.as_base_dtype(); + + v = variable_scope.default_variable_creator(init_val, + name: name, + trainable: trainable, + dtype: TF_DataType.DtInvalid, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); } } - // Create the variable. - if (use_resource == null) - use_resource = false; - - var v = variable_scope.default_variable_creator(init_val, - name: name, - trainable: trainable, - dtype: TF_DataType.DtInvalid, - validate_shape: validate_shape, - synchronization: synchronization, - aggregation: aggregation); - _vars[name] = v; return v; diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs new file mode 100644 index 00000000..0144a138 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class state_ops + { + /// + /// Create a variable Operation. + /// + /// + /// + /// + /// + /// + /// + public static Tensor variable_op_v2(long[] shape, + TF_DataType dtype, + string name = "Variable", + string container = "", + string shared_name = "") => gen_state_ops.variable_v2(shape, + dtype, + name: name, + container: container, + shared_name: shared_name); + } +} diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index bac69b76..e1f81f5d 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples // Parameters float learning_rate = 0.01f; - int training_epochs = 10000; + int training_epochs = 1000; int display_step = 50; public void Run() diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 15773152..26351a84 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -30,14 +30,18 @@ namespace TensorFlowNET.UnitTest var mammal2 = tf.Variable("Tiger"); } + /// + /// https://www.tensorflow.org/api_docs/python/tf/variable_scope + /// [TestMethod] - public void SimpleScope() + public void VarCreation1() { with(tf.variable_scope("foo"), delegate { with(tf.variable_scope("bar"), delegate { var v = tf.get_variable("v", new TensorShape(1)); + Assert.AreEqual(v.name, "foo/bar/v:0"); }); }); }