From 02871910c187d47a91cc61bb5da2d50ef849edee Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 11 Oct 2019 07:21:15 -0500 Subject: [PATCH] create ResourceVariable, still with buggy. --- TensorFlow.NET.sln | 66 +++++++++++++++++++ src/TensorFlowNET.Core/Binding.Util.cs | 1 + src/TensorFlowNET.Core/Graphs/Graph.cs | 4 +- .../Keras/Utils/base_layer_utils.cs | 3 +- .../Operations/gen_resource_variable_ops.cs | 7 ++ .../Operations/resource_variable_ops.cs | 23 +++++++ .../TensorFlowNET.Core.csproj | 15 +++-- .../Variables/RefVariable.cs | 2 +- .../Variables/ResourceVariable.cs | 55 +++++++++++++++- .../Variables/VariableV1.cs | 4 ++ .../Variables/variables.py.cs | 54 +++++++++++++++ src/TensorFlowNET.Core/ops.cs | 2 +- .../TensorFlowNET.Hub.csproj | 2 +- .../Keras/EmbeddingTest.cs | 3 + 14 files changed, 229 insertions(+), 12 deletions(-) diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 96a8af5c..16f524a4 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -15,36 +15,102 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Hub", "src\Te EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Datasets", "src\TensorFlowNET.Datasets\TensorFlowNET.Datasets.csproj", "{494D6CAD-2C0D-4C0B-90E2-B097DB039383}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU + Debug|x64 = Debug|x64 + Publish|Any CPU = Publish|Any CPU + Publish|x64 = Publish|x64 Release|Any CPU = Release|Any CPU + Release|x64 = Release|x64 EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|x64.ActiveCfg = Debug|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|x64.Build.0 = Debug|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|Any CPU.Build.0 = Release|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|x64.ActiveCfg = Release|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|x64.Build.0 = Release|Any CPU {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.ActiveCfg = Release|Any CPU {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.Build.0 = Release|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|x64.ActiveCfg = Release|Any CPU + {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|x64.Build.0 = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|Any CPU {D03F94CF-B283-4730-B177-21A57641061F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {D03F94CF-B283-4730-B177-21A57641061F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Debug|x64.ActiveCfg = Debug|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Debug|x64.Build.0 = Debug|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Publish|Any CPU.Build.0 = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Publish|x64.ActiveCfg = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Publish|x64.Build.0 = Release|Any CPU {D03F94CF-B283-4730-B177-21A57641061F}.Release|Any CPU.ActiveCfg = Release|Any CPU {D03F94CF-B283-4730-B177-21A57641061F}.Release|Any CPU.Build.0 = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Release|x64.ActiveCfg = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Release|x64.Build.0 = Release|Any CPU {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|x64.ActiveCfg = Debug|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|x64.Build.0 = Debug|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|Any CPU.Build.0 = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|x64.ActiveCfg = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|x64.Build.0 = Release|Any CPU {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|Any CPU.ActiveCfg = Release|Any CPU {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|Any CPU.Build.0 = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|x64.ActiveCfg = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|x64.Build.0 = Release|Any CPU {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|x64.ActiveCfg = Debug|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|x64.Build.0 = Debug|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|Any CPU.Build.0 = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|x64.ActiveCfg = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|x64.Build.0 = Release|Any CPU {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|Any CPU.ActiveCfg = Release|Any CPU {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|Any CPU.Build.0 = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|x64.ActiveCfg = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|x64.Build.0 = Release|Any CPU {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|Any CPU.Build.0 = Debug|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|x64.ActiveCfg = Debug|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|x64.Build.0 = Debug|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|Any CPU.Build.0 = Release|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|x64.ActiveCfg = Release|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|x64.Build.0 = Release|Any CPU {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.ActiveCfg = Release|Any CPU {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.Build.0 = Release|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|x64.ActiveCfg = Release|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|x64.Build.0 = Release|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|x64.ActiveCfg = Debug|x64 + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|x64.Build.0 = Debug|x64 + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|Any CPU.Build.0 = Publish|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|x64.ActiveCfg = Publish|x64 + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|x64.Build.0 = Publish|x64 + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|Any CPU.Build.0 = Release|Any CPU + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|x64.ActiveCfg = Release|x64 + {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|x64.Build.0 = Release|x64 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index dcf191ed..334f4f74 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -113,6 +113,7 @@ namespace Tensorflow } } + [DebuggerStepThrough] [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception public static void tf_with(T py, Action action) where T : IObjectLife { diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 7119a4ad..0c43582d 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -266,7 +266,7 @@ namespace Tensorflow name = op_type; // If a names ends with a '/' it is a "name scope" and we use it as-is, // after removing the trailing '/'. - name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); + name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); var input_ops = inputs.Select(x => x.op).ToArray(); @@ -341,7 +341,7 @@ namespace Tensorflow if (string.IsNullOrEmpty(name)) new_stack = ""; else if (name.EndsWith("/")) - new_stack = ops._name_from_scope_name(name); + new_stack = ops.name_from_scope_name(name); else new_stack = unique_name(name); diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index 477cc56f..6e2dc745 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -49,7 +49,8 @@ namespace Tensorflow.Keras.Utils var v = tf.VariableV1(init_val, use_resource: use_resource, dtype: dtype, - shape: shape); + shape: shape, + name: name); return v; } diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index 304a5b55..7b00b604 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -26,5 +26,12 @@ namespace Tensorflow return _op; } + + public static Tensor var_is_initialized_op(Tensor resource, string name = null) + { + var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); + + return _op; + } } } diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 41bd0ddf..b301063c 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -74,6 +74,29 @@ namespace Tensorflow return var is ResourceVariable; } + /// + /// Creates a variable handle with information to do shape inference. + /// + /// + /// + /// + /// + /// + /// + public static Tensor eager_safe_variable_handle(Tensor initial_value, TensorShape shape, + string shared_name, string name, bool graph_mode) + { + var dtype = initial_value.dtype.as_base_dtype(); + return variable_handle_from_shape_and_dtype( + shape, dtype, shared_name, name, graph_mode, initial_value); + } + + public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype, + string shared_name, string name, bool graph_mode, Tensor extra_handle_data = null) + { + throw new NotImplementedException(""); + } + /// /// Represents a future for a read of a variable. /// Pretends to be the tensor if anyone looks. diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 12a4c5f3..33914c3a 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.5 + 0.11.6 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.5.0 + 0.11.6.0 Changes since v0.10.0: 1. Upgrade NumSharp to v0.20.3. 2. Add DisposableObject class to manage object lifetime. @@ -29,9 +29,11 @@ Docs: https://tensorflownet.readthedocs.io 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. 9. MultiThread is safe. 10. Support n-dim indexing for tensor. -11. Add RegisterNoGradient +11. Add RegisterNoGradients +12. Add CumsumGrad, BroadcastToGrad. +13. Return VariableV1 instead of RefVariable. 7.3 - 0.11.5.0 + 0.11.6.0 LICENSE true true @@ -63,7 +65,6 @@ Docs: https://tensorflownet.readthedocs.io - @@ -71,4 +72,8 @@ Docs: https://tensorflownet.readthedocs.io + + + + diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 35e0da87..4b0a35fb 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -143,7 +143,7 @@ namespace Tensorflow // Use attr_scope and device(None) to simulate the behavior of // colocate_with when the variable we want to colocate with doesn't // yet exist. - string true_name = ops._name_from_scope_name(name); + string true_name = ops.name_from_scope_name(name); var attr = new AttrValue { List = new AttrValue.Types.ListValue() diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index b548a50f..83774734 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System; using System.Collections.Generic; using static Tensorflow.Binding; @@ -92,7 +93,59 @@ namespace Tensorflow var init_from_fn = initial_value.GetType().Name == "Func`1"; if(collections == null) collections = new List() { tf.GraphKeys.GLOBAL_VARIABLES }; - + _trainable = trainable; + _graph_key = ops.get_default_graph().graph_key; + + ops.init_scope(); + _in_graph_mode = true; + tf_with(ops.name_scope(name, "Variable"), scope => + { + name = scope; + var handle_name = ops.name_from_scope_name(name); + var shared_name = handle_name; + var unique_id = shared_name; + + var attr = new AttrValue(); + attr.List = new AttrValue.Types.ListValue(); + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{handle_name}")); + tf_with(ops.name_scope("Initializer"), delegate + { + initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func)() : initial_value, + name: "initial_value", + dtype: dtype); + }); + _shape = shape ?? (initial_value as Tensor).TensorShape; + _handle = resource_variable_ops.eager_safe_variable_handle( + initial_value: _initial_value, + shape: _shape, + shared_name: shared_name, + name: name, + graph_mode: _in_graph_mode); + _unique_id = unique_id; + _initial_value = initial_value as Tensor; + _handle_name = handle_name + ":0"; + _dtype = _initial_value.dtype.as_base_dtype(); + // _constraint = constraint; + + if (_in_graph_mode) + { + tf_with(ops.name_scope("IsInitialized"), delegate + { + _is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(_handle); + }); + if(initial_value != null) + { + tf_with(ops.name_scope("Assign"), scope1 => + { + string n = scope1; + _initializer_op = gen_resource_variable_ops.assign_variable_op(_handle, + variables._try_guard_against_uninitialized_dependencies(name, _initial_value), + name: n); + }); + } + } + }); + throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs index e1247f8d..8f873291 100644 --- a/src/TensorFlowNET.Core/Variables/VariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs @@ -35,7 +35,11 @@ namespace Tensorflow public virtual Operation op { get; } public virtual Operation initializer { get; } public Tensor _variable; + protected string _graph_key; public Graph graph => _variable.graph; + + public Tensor _is_initialized_op { get; set; } + public VariableV1(object initial_value = null, bool trainable = true, List collections = null, diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 6e9d0e4c..d898a4aa 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -84,6 +84,60 @@ namespace Tensorflow return gen_control_flow_ops.no_op(name: name); } + public static Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value) + { + return _safe_initial_value_from_tensor(name, initial_value, new Dictionary()); + } + + public static Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary op_cache) + { + var op = tensor.op; + Operation new_op = op_cache.ContainsKey(op.name) ? op_cache[op.name] : null; + if(new_op == null) + { + new_op = _safe_initial_value_from_op(name, op, op_cache); + op_cache[op.name] = new_op; + } + + return new_op.outputs[tensor.value_index]; + } + + /// + /// Replace dependencies on variables with their initialized values. + /// + /// + /// + /// + /// + public static Operation _safe_initial_value_from_op(string name, Operation op, Dictionary op_cache) + { + var op_type = op.node_def.Op; + if (op_type == "IsVariableInitialized" || + op_type == "VarIsInitializedOp" || + op_type == "ReadVariableOp") + return op; + + if(op_type == "Variable" || + op_type == "VariableV2" || + op_type == "VarHandleOp") + { + throw new NotImplementedException(""); + } + + // Recursively build initializer expressions for inputs. + bool modified = false; + var new_op_inputs = new List(); + foreach(Tensor op_input in op.inputs) + { + var new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache); + new_op_inputs.Add(new_op_input); + modified = modified || new_op_input != op_input; + } + + // If at least one input was modified, replace the op. + return op; + } + public static Tensor global_variables_initializer() { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 846de1ea..d1e423c9 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -274,7 +274,7 @@ namespace Tensorflow return node_def; } - public static string _name_from_scope_name(string name) + public static string name_from_scope_name(string name) { if (name.EndsWith("/")) { diff --git a/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj b/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj index 27b5128b..10d27a5c 100644 --- a/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj +++ b/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj @@ -18,6 +18,6 @@ TensorFlow.Hub - + \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs b/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs index 896ad430..0168f22c 100644 --- a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs @@ -8,6 +8,9 @@ using NumSharp; namespace TensorFlowNET.UnitTest.Keras { + /// + /// https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/Embedding + /// [TestClass] public class EmbeddingTest {