Browse Source

remove global static Graph instance.

tags/v0.9
Oceania2018 6 years ago
parent
commit
c4a585c320
7 changed files with 76 additions and 22 deletions
  1. +40
    -0
      src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
  2. +9
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +5
    -6
      src/TensorFlowNET.Core/Layers/Layer.cs
  4. +12
    -0
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  5. +2
    -1
      src/TensorFlowNET.Core/ops.name_scope.cs
  6. +6
    -10
      src/TensorFlowNET.Core/ops.py.cs
  7. +2
    -3
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

+ 40
- 0
src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public class DefaultGraphStack
{
List<StackModel> stack = new List<StackModel>();

public void set_controller(Graph @default)
{
if (!stack.Exists(x => x.Graph == @default))
stack.Add(new StackModel { Graph = @default, IsDefault = true });

foreach (var s in stack)
s.IsDefault = s.Graph == @default;
}

public Graph get_controller()
{
if (stack.Count == 0)
stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true });

return stack.First(x => x.IsDefault).Graph;
}

public void reset()
{
stack.Clear();
}
}

public class StackModel
{
public Graph Graph { get; set; }
public bool IsDefault { get; set; }
}
}

+ 9
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -87,7 +87,7 @@ namespace Tensorflow
private Dictionary<string, object> _collections = new Dictionary<string, object>();

public bool building_function;
public Graph()
{
_handle = c_api.TF_NewGraph();
@@ -113,7 +113,14 @@ namespace Tensorflow
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
}

public Graph as_default() => ops.set_default_graph(this);
/// <summary>
/// Returns a context manager that makes this `Graph` the default graph.
/// </summary>
/// <returns></returns>
public Graph as_default()
{
return ops.set_default_graph(this);
}

private Tensor _as_graph_element(object obj)
{


+ 5
- 6
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -172,13 +172,12 @@ namespace Tensorflow.Layers
}
else
{
with(tf.variable_scope(scope, default_name: _base_name),
captured_scope =>
{
_scope = captured_scope;
});
with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
{
// convert variable_scope to VariableScope
_scope = captured_scope;
});
}

}
}
}


+ 12
- 0
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -26,6 +26,7 @@ namespace Tensorflow
private bool? _reuse;
bool _in_graph_mode;
protected Graph _graph;
bool _building_function;

public variable_scope(string name,
string default_name = "",
@@ -70,6 +71,17 @@ namespace Tensorflow

public void __enter__()
{
// If the default graph is building a function, then we should not replace it
// with the cached graph.
if (ops.get_default_graph().building_function)
_building_function = true;
else
_building_function = false;
if (_in_graph_mode && !_building_function)
{
_graph.as_default();
}

_scope = _enter_scope_uncached();
}



+ 2
- 1
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -54,12 +54,13 @@ namespace Tensorflow
public void Dispose()
{
var g = get_default_graph();
// Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}");
g._name_stack = old_stack;
// Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}");
}

public void __exit__()
{

}

/// <summary>


+ 6
- 10
src/TensorFlowNET.Core/ops.py.cs View File

@@ -50,7 +50,8 @@ namespace Tensorflow
return get_default_graph().get_collection_ref(key);
}

private static Graph default_graph;
public static DefaultGraphStack default_graph_stack = new DefaultGraphStack();

/// <summary>
/// Returns the default graph for the current thread.
///
@@ -68,15 +69,13 @@ namespace Tensorflow
{
//TODO: original source indicates there should be a _default_graph_stack!
//return _default_graph_stack.get_default()
if (default_graph == null)
default_graph = tf.Graph();
return default_graph;
return default_graph_stack.get_controller();
}
public static Graph set_default_graph(Graph graph)
{
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
default_graph = graph;
return default_graph;
default_graph_stack.set_controller(graph);
return default_graph_stack.get_controller();
}

/// <summary>
@@ -96,10 +95,7 @@ namespace Tensorflow
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
// "nested graphs. If you need a cleared graph, " +
// "exit the nesting and create a new graph.");
//_default_graph_stack.reset();
if (default_graph!=null)
default_graph.Dispose();
default_graph = tf.Graph();
default_graph_stack.reset();
}

public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)


+ 2
- 3
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -195,7 +195,7 @@ namespace TensorFlowNET.Examples
return graph;
}

private bool RunWithImportedGraph(Session sess, Graph graph)
private bool Train(Session sess, Graph graph)
{
var stopwatch = Stopwatch.StartNew();

@@ -274,8 +274,7 @@ namespace TensorFlowNET.Examples
{
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

return with(tf.Session(graph), sess
=> RunWithImportedGraph(sess, graph));
return with(tf.Session(graph), sess => Train(sess, graph));
}

public bool Predict()


Loading…
Cancel
Save