@@ -0,0 +1,40 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class DefaultGraphStack | |||||
{ | |||||
List<StackModel> stack = new List<StackModel>(); | |||||
public void set_controller(Graph @default) | |||||
{ | |||||
if (!stack.Exists(x => x.Graph == @default)) | |||||
stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||||
foreach (var s in stack) | |||||
s.IsDefault = s.Graph == @default; | |||||
} | |||||
public Graph get_controller() | |||||
{ | |||||
if (stack.Count == 0) | |||||
stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||||
return stack.First(x => x.IsDefault).Graph; | |||||
} | |||||
public void reset() | |||||
{ | |||||
stack.Clear(); | |||||
} | |||||
} | |||||
public class StackModel | |||||
{ | |||||
public Graph Graph { get; set; } | |||||
public bool IsDefault { get; set; } | |||||
} | |||||
} |
@@ -87,7 +87,7 @@ namespace Tensorflow | |||||
private Dictionary<string, object> _collections = new Dictionary<string, object>(); | private Dictionary<string, object> _collections = new Dictionary<string, object>(); | ||||
public bool building_function; | public bool building_function; | ||||
public Graph() | public Graph() | ||||
{ | { | ||||
_handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
@@ -113,7 +113,14 @@ namespace Tensorflow | |||||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
} | } | ||||
public Graph as_default() => ops.set_default_graph(this); | |||||
/// <summary> | |||||
/// Returns a context manager that makes this `Graph` the default graph. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public Graph as_default() | |||||
{ | |||||
return ops.set_default_graph(this); | |||||
} | |||||
private Tensor _as_graph_element(object obj) | private Tensor _as_graph_element(object obj) | ||||
{ | { | ||||
@@ -172,13 +172,12 @@ namespace Tensorflow.Layers | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
with(tf.variable_scope(scope, default_name: _base_name), | |||||
captured_scope => | |||||
{ | |||||
_scope = captured_scope; | |||||
}); | |||||
with(tf.variable_scope(scope, default_name: _base_name), captured_scope => | |||||
{ | |||||
// convert variable_scope to VariableScope | |||||
_scope = captured_scope; | |||||
}); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -26,6 +26,7 @@ namespace Tensorflow | |||||
private bool? _reuse; | private bool? _reuse; | ||||
bool _in_graph_mode; | bool _in_graph_mode; | ||||
protected Graph _graph; | protected Graph _graph; | ||||
bool _building_function; | |||||
public variable_scope(string name, | public variable_scope(string name, | ||||
string default_name = "", | string default_name = "", | ||||
@@ -70,6 +71,17 @@ namespace Tensorflow | |||||
public void __enter__() | public void __enter__() | ||||
{ | { | ||||
// If the default graph is building a function, then we should not replace it | |||||
// with the cached graph. | |||||
if (ops.get_default_graph().building_function) | |||||
_building_function = true; | |||||
else | |||||
_building_function = false; | |||||
if (_in_graph_mode && !_building_function) | |||||
{ | |||||
_graph.as_default(); | |||||
} | |||||
_scope = _enter_scope_uncached(); | _scope = _enter_scope_uncached(); | ||||
} | } | ||||
@@ -54,12 +54,13 @@ namespace Tensorflow | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
var g = get_default_graph(); | var g = get_default_graph(); | ||||
// Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); | |||||
g._name_stack = old_stack; | g._name_stack = old_stack; | ||||
// Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); | |||||
} | } | ||||
public void __exit__() | public void __exit__() | ||||
{ | { | ||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -50,7 +50,8 @@ namespace Tensorflow | |||||
return get_default_graph().get_collection_ref(key); | return get_default_graph().get_collection_ref(key); | ||||
} | } | ||||
private static Graph default_graph; | |||||
public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the default graph for the current thread. | /// Returns the default graph for the current thread. | ||||
/// | /// | ||||
@@ -68,15 +69,13 @@ namespace Tensorflow | |||||
{ | { | ||||
//TODO: original source indicates there should be a _default_graph_stack! | //TODO: original source indicates there should be a _default_graph_stack! | ||||
//return _default_graph_stack.get_default() | //return _default_graph_stack.get_default() | ||||
if (default_graph == null) | |||||
default_graph = tf.Graph(); | |||||
return default_graph; | |||||
return default_graph_stack.get_controller(); | |||||
} | } | ||||
public static Graph set_default_graph(Graph graph) | public static Graph set_default_graph(Graph graph) | ||||
{ | { | ||||
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | ||||
default_graph = graph; | |||||
return default_graph; | |||||
default_graph_stack.set_controller(graph); | |||||
return default_graph_stack.get_controller(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -96,10 +95,7 @@ namespace Tensorflow | |||||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | ||||
// "nested graphs. If you need a cleared graph, " + | // "nested graphs. If you need a cleared graph, " + | ||||
// "exit the nesting and create a new graph."); | // "exit the nesting and create a new graph."); | ||||
//_default_graph_stack.reset(); | |||||
if (default_graph!=null) | |||||
default_graph.Dispose(); | |||||
default_graph = tf.Graph(); | |||||
default_graph_stack.reset(); | |||||
} | } | ||||
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | ||||
@@ -195,7 +195,7 @@ namespace TensorFlowNET.Examples | |||||
return graph; | return graph; | ||||
} | } | ||||
private bool RunWithImportedGraph(Session sess, Graph graph) | |||||
private bool Train(Session sess, Graph graph) | |||||
{ | { | ||||
var stopwatch = Stopwatch.StartNew(); | var stopwatch = Stopwatch.StartNew(); | ||||
@@ -274,8 +274,7 @@ namespace TensorFlowNET.Examples | |||||
{ | { | ||||
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | ||||
return with(tf.Session(graph), sess | |||||
=> RunWithImportedGraph(sess, graph)); | |||||
return with(tf.Session(graph), sess => Train(sess, graph)); | |||||
} | } | ||||
public bool Predict() | public bool Predict() | ||||