From 8b32de72efd13894d57cd80208f85cf17940164b Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 12 Jan 2019 09:45:11 -0600 Subject: [PATCH] Fixed #111 --- .../Gradients/gradients_impl.py.cs | 7 + src/TensorFlowNET.Core/Graphs/Graph.cs | 1 - .../Operations/OpDefLibrary.cs | 161 +++++++++--------- .../TensorFlowNET.Core.csproj | 6 +- src/TensorFlowNET.Core/Tensors/TF_DataType.cs | 8 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 7 +- src/TensorFlowNET.Core/Train/Optimizer.cs | 11 +- .../Variables/RefVariable.cs | 11 ++ .../Variables/gen_state_ops.py.cs | 29 +++- src/TensorFlowNET.Core/ops.name_scope.cs | 14 +- .../TensorFlowNET.Examples.csproj | 2 +- .../TensorFlowNET.UnitTest.csproj | 2 +- 12 files changed, 166 insertions(+), 93 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 954f5b11..53288942 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -30,6 +30,10 @@ namespace Tensorflow if (src_graph == null) src_graph = ops.get_default_graph(); + // If src_graph is a _FuncGraph (i.e. a function body), gather it and all + // ancestor graphs. This is necessary for correctly handling captured values. + var curr_graph = src_graph; + var ys1 = _AsList(ys); var xs1 = _AsList(xs); List grad_ys1 = null; @@ -47,7 +51,10 @@ namespace Tensorflow string grad_scope = ""; using (var namescope = new ops.name_scope(name, "gradients", values: all)) + { grad_scope = namescope; + + } } private static List _AsList(object ys) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index d2ed0c99..cb5784ce 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -173,7 +173,6 @@ namespace Tensorflow string new_stack = ""; - if (name.EndsWith("/")) new_stack = ops._name_from_scope_name(name); else diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index a5ee3b23..b31be3c0 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -15,14 +15,15 @@ namespace Tensorflow var g = ops.get_default_graph(); var op_def = g.GetOpDef(op_type_name); + // Default name if not specified. if (String.IsNullOrEmpty(name)) - { name = op_type_name; - } - string scope = ""; - using (var namescope = new ops.name_scope(name)) - scope = namescope; + // Check for deprecation + if(op_def.Deprecation != null && op_def.Deprecation.Version > 0) + { + + } var default_type_attr_map = new Dictionary(); foreach (var attr_def in op_def.Attr) @@ -39,101 +40,107 @@ namespace Tensorflow var inputs = new List(); var input_types = new List(); - // Perform input type inference - foreach (var input_arg in op_def.InputArg) + string scope = ""; + using (var namescope = new ops.name_scope(name)) { - var input_name = input_arg.Name; - if (keywords[input_name] is double int_value) - { - keywords[input_name] = constant_op.Constant(int_value, input_name); - } + scope = namescope; - if (keywords[input_name] is Tensor value) + // Perform input type inference + foreach (var input_arg in op_def.InputArg) { - if (keywords.ContainsKey(input_name)) + var input_name = input_arg.Name; + if (keywords[input_name] is double int_value) { - inputs.Add(value); + keywords[input_name] = constant_op.Constant(int_value, input_name); } - if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + if (keywords[input_name] is Tensor value) { - attrs[input_arg.TypeAttr] = value.dtype; + if (keywords.ContainsKey(input_name)) + { + inputs.Add(value); + } + + if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + { + attrs[input_arg.TypeAttr] = value.dtype; + } + + if (input_arg.IsRef) + { + + } + else + { + input_types.Add(value.dtype); + } } + } - if (input_arg.IsRef) - { - - } - else + // Process remaining attrs + foreach (var attr in op_def.Attr) + { + if (keywords.ContainsKey(attr.Name)) { - input_types.Add(value.dtype); + attrs[attr.Name] = keywords[attr.Name]; } } - } - // Process remaining attrs - foreach (var attr in op_def.Attr) - { - if (keywords.ContainsKey(attr.Name)) + // Convert attr values to AttrValue protos. + var attr_protos = new Dictionary(); + foreach (var attr_def in op_def.Attr) { - attrs[attr.Name] = keywords[attr.Name]; - } - } + var key = attr_def.Name; + var value = attrs[key]; + var attr_value = new AttrValue(); - // Convert attr values to AttrValue protos. - var attr_protos = new Dictionary(); - foreach (var attr_def in op_def.Attr) - { - var key = attr_def.Name; - var value = attrs[key]; - var attr_value = new AttrValue(); - - switch (attr_def.Type) - { - case "string": - attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); - break; - case "type": - attr_value.Type = _MakeType((TF_DataType)value, attr_def); - break; - case "bool": - attr_value.B = (bool)value; - break; - case "shape": - attr_value.Shape = value == null ? - attr_def.DefaultValue.Shape : - tensor_util.as_shape((long[])value); - break; - default: - throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); - } + switch (attr_def.Type) + { + case "string": + attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); + break; + case "type": + attr_value.Type = _MakeType((TF_DataType)value, attr_def); + break; + case "bool": + attr_value.B = (bool)value; + break; + case "shape": + attr_value.Shape = value == null ? + attr_def.DefaultValue.Shape : + tensor_util.as_shape((long[])value); + break; + default: + throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); + } - attr_protos[key] = attr_value; - } + attr_protos[key] = attr_value; + } - // Determine output types (possibly using attrs) - var output_types = new List(); + // Determine output types (possibly using attrs) + var output_types = new List(); - foreach (var arg in op_def.OutputArg) - { - if (!String.IsNullOrEmpty(arg.NumberAttr)) + foreach (var arg in op_def.OutputArg) { + if (!String.IsNullOrEmpty(arg.NumberAttr)) + { + } + else if (!String.IsNullOrEmpty(arg.TypeAttr)) + { + output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); + } } - else if (!String.IsNullOrEmpty(arg.TypeAttr)) - { - output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); - } - } - // Add Op to graph - var op = g.create_op(op_type_name, inputs, output_types.ToArray(), - name: scope, - input_types: input_types.ToArray(), - attrs: attr_protos, - op_def: op_def); + // Add Op to graph + var op = g.create_op(op_type_name, inputs, output_types.ToArray(), + name: scope, + input_types: input_types.ToArray(), + attrs: attr_protos, + op_def: op_def); - return op; + return op; + } } public DataType _MakeType(TF_DataType v, AttrDef attr_def) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index fa0e171b..49f9ae1c 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,9 +4,9 @@ netstandard2.0 TensorFlow.NET Tensorflow - 0.0.2 + 0.0.3 Haiping Chen - SciSharp.org + SciSharp STACK true Apache 2.0 https://github.com/SciSharp/TensorFlow.NET @@ -16,7 +16,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET Google's TensorFlow binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.0.2.0 + 0.0.3.0 API updated 7.2 diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 543fb4e4..b3e5b79b 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -4,6 +4,10 @@ using System.Text; namespace Tensorflow { + /// + /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. + /// The enum values here are identical to corresponding values in types.proto. + /// public enum TF_DataType { DtInvalid = 0, @@ -30,6 +34,8 @@ namespace Tensorflow TF_RESOURCE = 20, TF_VARIANT = 21, TF_UINT32 = 22, - TF_UINT64 = 23 + TF_UINT64 = 23, + + DtDoubleRef = 102, // DT_DOUBLE_REF } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index e5f868b7..11f185f8 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -19,7 +19,10 @@ namespace Tensorflow public Graph Graph => op.Graph; public Operation op { get; } - public string name; + /// + /// The string name of this tensor. + /// + public string name => $"{(op == null ? "Operation was not named" : $"{op.Name}:{value_index}")}"; public int value_index { get; } @@ -222,7 +225,7 @@ namespace Tensorflow } } - return $"{name} {dtype} {rank} {string.Join(",", shape)}"; + return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}"; } public void Dispose() diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 5d83a15e..7385a927 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -17,6 +17,10 @@ namespace Tensorflow public string Name { get; set; } public double LearningRate { get; set; } public Tensor LearningRateTensor { get; set; } + public bool _use_locking; + public Dictionary _slots; + public Dictionary _non_slot_dict; + public Dictionary _deferred_slot_restorations; public Optimizer(double learning_rate, bool use_locking, string name = "") { @@ -24,6 +28,11 @@ namespace Tensorflow throw new NotImplementedException("Must specify the optimizer name"); Name = name; + _use_locking = use_locking; + // Dictionary of slots. + _slots = new Dictionary(); + _non_slot_dict = new Dictionary(); + _deferred_slot_restorations = new Dictionary(); } /// @@ -68,7 +77,7 @@ namespace Tensorflow break; } - var processors = var_list.Select(v => optimizer._get_processor(v)); + var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); var var_refs = processors.Select(x => x.target()).ToList(); gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index e70684b7..e8fa72a0 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -79,6 +79,17 @@ namespace Tensorflow // have an issue if these other variables aren't initialized first by // using their initialized_value() method. + var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; + + if (!String.IsNullOrEmpty(caching_device)) + { + + } + else + { + + } + ops.add_to_collections(collections, this); } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 4ea475e2..e1585b42 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -33,5 +34,31 @@ namespace Tensorflow return new Tensor(_op, 0, dtype); } + + /// + /// Update 'ref' by assigning 'value' to it + /// + /// + /// + /// + /// + /// + public static Tensor assign(Tensor tensor, Tensor value, + bool validate_shape = true, + bool use_locking = true, + string name = "") + { + var keywords = new Dictionary(); + keywords.Add("ref", tensor); + keywords.Add("value", value); + keywords.Add("validate_shape", validate_shape); + keywords.Add("use_locking", use_locking); + + var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); + + var _result = _op.outputs[0]; + var _inputs_flat = _op.inputs; + return _result; + } } } diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 1fc70e46..3c7c124e 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -21,8 +21,6 @@ namespace Tensorflow _default_name = default_name; _values = values; _ctx = new Context(); - - _name_scope = __enter__(); } public string __enter__() @@ -38,8 +36,10 @@ namespace Tensorflow if (g == null) g = get_default_graph(); - - return g.name_scope(_name); ; + + _name_scope = g.name_scope(_name); + + return _name_scope; } public void Dispose() @@ -48,9 +48,13 @@ namespace Tensorflow g._name_stack = g.old_stack; } + /// + /// __enter__() + /// + /// public static implicit operator string(name_scope ns) { - return ns._name_scope; + return ns.__enter__(); } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 149d04f0..c66a7461 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -7,7 +7,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index f0222bf9..16a5fed3 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -20,7 +20,7 @@ - +