From 71e1fe6299a97f7ba7c91a2b40839ff43734197b Mon Sep 17 00:00:00 2001 From: Esther2013 Date: Thu, 10 Jan 2019 06:31:15 -0600 Subject: [PATCH] RefVariable, variable_scope --- src/TensorFlowNET.Core/Eager/Context.cs | 2 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 5 ++ .../Operations/OpDefLibrary.cs | 2 +- .../Variables/RefVariable.cs | 56 ++++++++++++++++--- .../Variables/VariableV1.cs | 8 ++- .../Variables/gen_state_ops.py.cs | 35 ++++++++++++ .../Variables/variable_scope.py.cs | 4 +- .../Variables/variables.py.cs | 2 +- .../{ops.GraphKeys.py.cs => ops.GraphKeys.cs} | 8 ++- src/TensorFlowNET.Core/ops.name_scope.cs | 44 +++++++++++++++ src/TensorFlowNET.Core/ops.py.cs | 35 +++++++----- src/TensorFlowNET.Core/tf.cs | 4 -- 12 files changed, 173 insertions(+), 32 deletions(-) create mode 100644 src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs rename src/TensorFlowNET.Core/{ops.GraphKeys.py.cs => ops.GraphKeys.cs} (73%) create mode 100644 src/TensorFlowNET.Core/ops.name_scope.cs diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs index 3d9c875d..a5bef053 100644 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.Eager +namespace Tensorflow { public class Context { diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index b36480db..82516ee9 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -152,6 +152,11 @@ namespace Tensorflow return false; } + public string get_name_scope() + { + return _name_stack; + } + public string name_scope(string name) { string new_stack = ""; diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 7c4d0c87..121d4163 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -10,7 +10,7 @@ namespace Tensorflow { public class OpDefLibrary { - public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary keywords = null) + public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary keywords = null) { var g = ops.get_default_graph(); var op_def = g.GetOpDef(op_type_name); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 6f129767..af861bbf 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -8,22 +8,62 @@ namespace Tensorflow { public bool _in_graph_mode = true; public Tensor _initial_value; + public string _graph_key; + public bool _trainable; + public Tensor _variable; - public RefVariable(object initial_value, + public RefVariable(object initial_value, + bool trainable = true, + List collections = null, + bool validate_shape = true, + string caching_device = "", string name = "", - TF_DataType trainable = TF_DataType.DtInvalid, - bool validate_shape = true) : - base(initial_value, name, trainable, validate_shape) + TF_DataType dtype = TF_DataType.DtInvalid) : + base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) { - _init_from_args(initial_value, name, trainable); + _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); } private void _init_from_args(object initial_value, + bool trainable = true, + List collections = null, + bool validate_shape = true, + string caching_device = "", string name = "", - TF_DataType trainable = TF_DataType.DtInvalid) + TF_DataType dtype = TF_DataType.DtInvalid) { - name = ops.name_scope("", "Variable", initial_value); - _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); + if (initial_value is null) + throw new ValueError("initial_value must be specified."); + + var init_from_fn = false; + + if(collections == null) + { + collections = new List { ops.GraphKeys.GLOBAL_VARIABLES }; + } + + // Store the graph key so optimizers know how to only retrieve variables from + // this graph. + _graph_key = ops.get_default_graph()._graph_key; + + _trainable = trainable; + if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) + collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); + + ops.init_scope(); + name = new ops.name_scope(name, "Variable", init_from_fn ? new List() : new List { initial_value }); + if (init_from_fn) + { + + } + else + { + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); + } + + var shape = _initial_value.shape; + dtype = _initial_value.dtype; + _variable = gen_state_ops.variable_v2(shape, dtype, name); } } } diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs index efd866fd..7d310f61 100644 --- a/src/TensorFlowNET.Core/Variables/VariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs @@ -16,7 +16,13 @@ namespace Tensorflow /// public class VariableV1 { - public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) + public VariableV1(object initial_value, + bool trainable = true, + List collections = null, + bool validate_shape = true, + string caching_device = "", + string name = "", + TF_DataType dtype = TF_DataType.DtInvalid) { } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs new file mode 100644 index 00000000..8843c475 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class gen_state_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + /// + /// Holds state in the form of a tensor that persists across steps. + /// Outputs a ref to the tensor state so it may be read or modified. + /// + /// The shape of the variable tensor. + /// The type of elements in the variable tensor. + /// + /// + /// + /// + public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "") + { + var keywords = new Dictionary(); + keywords.Add("dtype", dtype); + keywords.Add("shape", shape); + + var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); + + var _result = _op.outputs; + var _inputs_flat = _op.inputs; + + return new Tensor(_op, 0, dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index b794c0f1..2c05581a 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -26,7 +26,9 @@ namespace Tensorflow } else { - return new RefVariable(initial_value); + return new RefVariable(initial_value, + name: name, + dtype: dtype); } } diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 9a2602b2..1e9c426b 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// public static object trainable_variables() { - return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); + return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); } } } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.py.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs similarity index 73% rename from src/TensorFlowNET.Core/ops.GraphKeys.py.cs rename to src/TensorFlowNET.Core/ops.GraphKeys.cs index a7f03cf9..cfc74aff 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.py.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -15,12 +15,18 @@ namespace Tensorflow /// specified, but it is also possible to pass an explicit list of /// variables. /// - public static class GraphKey + public static class GraphKeys { /// /// the subset of `Variable` objects that will be trained by an optimizer. /// public static string TRAINABLE_VARIABLES = "trainable_variables"; + + /// + /// Key to collect Variable objects that are global (shared across machines). + /// Default collection for all variables, except local ones. + /// + public static string GLOBAL_VARIABLES = "variables"; } } } diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs new file mode 100644 index 00000000..fef5c283 --- /dev/null +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -0,0 +1,44 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class ops + { + public class name_scope + { + public string _name; + public string _default_name; + public object _values; + public Context _ctx; + public string _name_scope; + + public name_scope(string name, string default_name, List values) + { + _name = name; + _default_name = default_name; + _values = values; + _ctx = new Context(); + + _name_scope = __enter__(); + } + + public string __enter__() + { + if (String.IsNullOrEmpty(_name)) + { + _name = _default_name; + } + + var g = get_default_graph(); + return g.name_scope(_name); + } + + public static implicit operator string(name_scope ns) + { + return ns._name_scope; + } + } + } +} diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index a0dfc52a..4a1a6bcc 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -97,20 +97,6 @@ namespace Tensorflow return node_def; } - public static string name_scope(string name, string default_name = "", object values = null) - { - string _name = ""; - if (String.IsNullOrEmpty(name)) - { - _name = default_name; - } - - var g = get_default_graph(); - var _name_scope = g.name_scope(_name); - - return _name_scope; - } - public static string _name_from_scope_name(string name) { if (name.EndsWith("/")) @@ -123,6 +109,27 @@ namespace Tensorflow } } + /// + /// A context manager that lifts ops out of control-flow scopes and function-building graphs. + /// + /// + public static void init_scope() + { + // Retrieve the active name scope: entering an `init_scope` preserves + // the name scope of the current context. + var default_graph = get_default_graph(); + var scope = default_graph.get_name_scope(); + if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) + // Names that end with trailing slashes are treated by `name_scope` as + // absolute. + scope += "/"; + // inner_device_stack = default_graph._device_function_stack + // var outer_context = default_graph.as_default; + + var outer_graph = get_default_graph(); + // outer_device_stack = None + } + public static int uid() { return 1; diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 8c4d5611..3e21d929 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -1,10 +1,6 @@ using System; using System.Collections.Generic; -using System.Runtime.InteropServices; using System.Text; -using TF_DataType = Tensorflow.DataType; -using attr_value_pb2 = Tensorflow; -using Tensorflow.Eager; namespace Tensorflow {