From 43273b37d563920c2854b8446eff418af073a89d Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 25 Jan 2019 10:06:20 -0600 Subject: [PATCH] fixed #108 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 7 +------ src/TensorFlowNET.Core/Variables/RefVariable.cs | 2 ++ src/TensorFlowNET.Core/ops.py.cs | 9 ++++++++- src/TensorFlowNET.Core/tf.cs | 2 +- test/TensorFlowNET.UnitTest/VariableTest.cs | 2 +- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index dca54511..0bac4c8f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -37,15 +37,10 @@ namespace Tensorflow { _handle = c_api.TF_NewGraph(); Status = new Status(); - } - - public Graph(IntPtr graph) - { - _handle = graph; - Status = new Status(); _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); + _graph_key = $"grap-key-{ops.uid()}/"; } public T as_graph_element(T obj, bool allow_tensor = true, bool allow_operation = true) diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index d539cfa6..a864c2cd 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -28,6 +28,8 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid) : base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) { + _in_graph_mode = true; + _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 5fe914f6..9eb07e52 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -174,9 +174,16 @@ namespace Tensorflow // outer_device_stack = None } + private static int uid_number = 0; + + /// + /// A unique (within this program execution) integer. + /// Not thread safe + /// + /// public static int uid() { - return 1; + return uid_number++; } } } diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 5d63a411..4e28589c 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -15,7 +15,7 @@ namespace Tensorflow public static Context context; - public static Graph g = new Graph(c_api.TF_NewGraph()); + public static Graph g = new Graph(); public static RefVariable Variable(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) { diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 979a84cf..31ff8b40 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -35,7 +35,7 @@ namespace TensorFlowNET.UnitTest using (var session = tf.Session()) { - session.run(model); + var sm = session.run(model); for(int i = 0; i < 5; i++) { var x1 = x + 1;