@@ -0,0 +1,105 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
public Context _control_flow_context; | |||
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||
public Queue<_ControlDependenciesController> _control_dependencies_stack | |||
{ | |||
get | |||
{ | |||
return _graph_control_dependencies_stack; | |||
} | |||
set | |||
{ | |||
_graph_control_dependencies_stack = value; | |||
} | |||
} | |||
/// <summary> | |||
/// For an op that takes `input_ops` as inputs, compute control inputs. | |||
/// </summary> | |||
/// <param name="input_ops">The data input ops for an op to be created.</param> | |||
/// <returns>A list of control inputs for the op to be created.</returns> | |||
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||
{ | |||
Operation[] ret = new Operation[0]; | |||
foreach(var controller in _control_dependencies_stack) | |||
{ | |||
bool dominated = false; | |||
// If any of the input_ops already depends on the inputs from controller, | |||
// we say that the new op is dominated (by that input), and we therefore | |||
// do not need to add control dependencies for this controller's inputs. | |||
foreach(var op in input_ops) | |||
{ | |||
if (controller.op_in_group(op)) | |||
{ | |||
dominated = true; | |||
break; | |||
} | |||
} | |||
if (!dominated) | |||
ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray(); | |||
} | |||
return ret; | |||
} | |||
public _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
{ | |||
if (control_inputs == null) | |||
return new _ControlDependenciesController(this, null); | |||
var control_ops = new List<Operation>(); | |||
foreach (var c in control_inputs) | |||
{ | |||
control_ops.Add(c); | |||
} | |||
return new _ControlDependenciesController(this, control_ops); | |||
} | |||
/// <summary> | |||
/// Returns the current control flow context. | |||
/// </summary> | |||
/// <returns>A context object.</returns> | |||
public Context _get_control_flow_context() | |||
{ | |||
return _control_flow_context; | |||
} | |||
/// <summary> | |||
/// Sets the current control flow context. | |||
/// </summary> | |||
/// <param name="ctx">a context object.</param> | |||
public void _set_control_flow_context(Context ctx) | |||
{ | |||
_control_flow_context = ctx; | |||
} | |||
public void _push_control_dependencies_controller(_ControlDependenciesController controller) | |||
{ | |||
_control_dependencies_stack.Enqueue(controller); | |||
} | |||
public void _pop_control_dependencies_controller(_ControlDependenciesController controller) | |||
{ | |||
_control_dependencies_stack.Dequeue(); | |||
} | |||
public void _record_op_seen_by_control_dependencies(Operation op) | |||
{ | |||
foreach (var controller in _control_dependencies_stack) | |||
controller.add_op(op); | |||
} | |||
} | |||
} |
@@ -142,19 +142,9 @@ namespace Tensorflow | |||
return op; | |||
} | |||
/// <summary> | |||
/// For an op that takes `input_ops` as inputs, compute control inputs. | |||
/// </summary> | |||
/// <param name="input_ops">The data input ops for an op to be created.</param> | |||
/// <returns>A list of control inputs for the op to be created.</returns> | |||
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||
{ | |||
return new Operation[0]; | |||
} | |||
private void _create_op_helper(Operation op, bool compute_device = true) | |||
{ | |||
_record_op_seen_by_control_dependencies(op); | |||
} | |||
public void _add_op(Operation op) | |||
@@ -0,0 +1,80 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Context manager for `control_dependencies()` | |||
/// </summary> | |||
public class _ControlDependenciesController : IPython | |||
{ | |||
private Graph _graph; | |||
private List<Operation> _control_inputs_val; | |||
private List<Operation> _seen_nodes; | |||
private Queue<_ControlDependenciesController> _old_stack; | |||
private bool _new_stack; | |||
private Context _old_control_flow_context; | |||
public Operation[] control_inputs => _control_inputs_val.ToArray(); | |||
public _ControlDependenciesController(Graph graph, List<Operation> control_inputs) | |||
{ | |||
_graph = graph; | |||
if (control_inputs == null) | |||
{ | |||
_control_inputs_val = new List<Operation>(); | |||
_new_stack = true; | |||
} | |||
else | |||
{ | |||
_control_inputs_val = control_inputs; | |||
_new_stack = false; | |||
} | |||
_seen_nodes = new List<Operation>(); | |||
} | |||
public void add_op(Operation op) | |||
{ | |||
_seen_nodes.Add(op); | |||
} | |||
public bool op_in_group(Operation op) | |||
{ | |||
return _seen_nodes.Contains(op); | |||
} | |||
public void __enter__() | |||
{ | |||
if (_new_stack) | |||
{ | |||
// Clear the control_dependencies graph. | |||
_old_stack = _graph._control_dependencies_stack; | |||
_graph._control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||
// Clear the control_flow_context too. | |||
_old_control_flow_context = _graph._get_control_flow_context(); | |||
_graph._set_control_flow_context(null); | |||
} | |||
_graph._push_control_dependencies_controller(this); | |||
} | |||
public void __exit__() | |||
{ | |||
_graph._pop_control_dependencies_controller(this); | |||
if (_new_stack) | |||
{ | |||
_graph._control_dependencies_stack = _old_stack; | |||
_graph._set_control_flow_context(_old_control_flow_context); | |||
} | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -39,6 +39,14 @@ namespace Tensorflow | |||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
public Operation[] control_inputs | |||
{ | |||
get | |||
{ | |||
return GetControlInputs(); | |||
} | |||
} | |||
public unsafe Operation[] GetControlInputs() | |||
{ | |||
var control_inputs = new Operation[NumControlInputs]; | |||
@@ -49,15 +49,53 @@ namespace Tensorflow | |||
c_api.TF_FinishOperation(desc, status); | |||
} | |||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
/// <summary> | |||
/// Creates an `Operation`. | |||
/// </summary> | |||
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||
/// <param name="g">`Graph`. The parent graph.</param> | |||
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||
/// <param name="output_types">list of `DType` objects.</param> | |||
/// <param name="control_inputs"> | |||
/// list of operations or tensors from which to have a | |||
/// control dependency. | |||
/// </param> | |||
/// <param name="input_types"> | |||
/// List of `DType` objects representing the | |||
/// types of the tensors accepted by the `Operation`. By default | |||
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||
/// reference-typed inputs must specify these explicitly. | |||
/// </param> | |||
/// <param name="original_op"></param> | |||
/// <param name="op_def"></param> | |||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
{ | |||
Graph = g; | |||
// Build the list of control inputs. | |||
var control_input_ops = new List<Operation>(); | |||
if(control_inputs != null) | |||
{ | |||
foreach(var c in control_inputs) | |||
{ | |||
switch (c) | |||
{ | |||
case Operation c1: | |||
control_input_ops.Add(c1); | |||
break; | |||
default: | |||
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||
} | |||
} | |||
} | |||
// This will be set by self.inputs. | |||
_id_value = Graph._next_id(); | |||
if(op_def == null) | |||
op_def = g.GetOpDef(node_def.Op); | |||
_handle = ops._create_c_op(g, node_def, inputs); | |||
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | |||
output_types = new TF_DataType[NumOutputs]; | |||
@@ -34,6 +34,14 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_AddInput(IntPtr desc, TF_Output input); | |||
/// <summary> | |||
/// Call once per control input to `desc`. | |||
/// </summary> | |||
/// <param name="desc">TF_OperationDescription*</param> | |||
/// <param name="input">TF_Operation*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); | |||
/// <summary> | |||
/// For inputs that take a list of tensors. | |||
/// inputs must point to TF_Output[num_inputs]. | |||
@@ -13,9 +13,8 @@ namespace Tensorflow | |||
{ | |||
name = namescope; | |||
var ops_on_device = new Dictionary<string, Operation[]>(); | |||
// Sorts *inputs according to their devices. | |||
var ops_on_device = new Dictionary<string, Operation[]>(); | |||
foreach (var inp in inputs) | |||
{ | |||
ops_on_device[inp.Device] = new Operation[] { inp }; | |||
@@ -24,7 +23,9 @@ namespace Tensorflow | |||
// 1-level tree. The root node is the returned NoOp node. | |||
if (ops_on_device.Count == 1) | |||
{ | |||
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); | |||
var dev = ops_on_device.Keys.First(); | |||
var deps = ops_on_device.Values.First(); | |||
return _GroupControlDeps(dev, deps, name); | |||
} | |||
// 2-level tree. The root node is the returned NoOp node. | |||
@@ -35,12 +36,21 @@ namespace Tensorflow | |||
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | |||
{ | |||
if (string.IsNullOrEmpty(dev)) | |||
Operation result = null; | |||
Python.with(ops.control_dependencies(deps), delegate | |||
{ | |||
return gen_control_flow_ops.no_op(name); | |||
} | |||
if (string.IsNullOrEmpty(dev)) | |||
{ | |||
result = gen_control_flow_ops.no_op(name); | |||
} | |||
else | |||
{ | |||
result = gen_control_flow_ops.no_op(name); | |||
} | |||
}); | |||
return null; | |||
return result; | |||
} | |||
} | |||
} |
@@ -13,5 +13,30 @@ namespace Tensorflow | |||
{ | |||
Console.WriteLine(obj.ToString()); | |||
} | |||
public static void with(IPython py, Action action) | |||
{ | |||
try | |||
{ | |||
py.__enter__(); | |||
action(); | |||
} | |||
catch (Exception ex) | |||
{ | |||
throw ex; | |||
} | |||
finally | |||
{ | |||
py.__exit__(); | |||
py.Dispose(); | |||
} | |||
} | |||
} | |||
public interface IPython : IDisposable | |||
{ | |||
void __enter__(); | |||
void __exit__(); | |||
} | |||
} |
@@ -20,5 +20,10 @@ namespace Tensorflow | |||
{ | |||
return var._AsTensor(); | |||
} | |||
public static implicit operator RefVariable(Tensor var) | |||
{ | |||
return null; | |||
} | |||
} | |||
} |
@@ -166,5 +166,10 @@ namespace Tensorflow | |||
// Recursively build initializer expressions for inputs. | |||
return op; | |||
} | |||
public override string ToString() | |||
{ | |||
return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; | |||
} | |||
} | |||
} |
@@ -78,7 +78,29 @@ namespace Tensorflow | |||
} | |||
} | |||
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | |||
/// <summary> | |||
/// Wrapper for `Graph.control_dependencies()` using the default graph. | |||
/// </summary> | |||
/// <param name="control_inputs"></param> | |||
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
{ | |||
return get_default_graph().control_dependencies(control_inputs); | |||
} | |||
/// <summary> | |||
/// Creates a TF_Operation. | |||
/// </summary> | |||
/// <param name="graph">a `Graph`.</param> | |||
/// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param> | |||
/// <param name="inputs"> | |||
/// A list of `Tensor`s (corresponding to scalar inputs) and lists of | |||
/// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", | |||
/// "list(int64)"). The length of the list should be equal to the number of | |||
/// inputs specified by this operation's op def. | |||
/// </param> | |||
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | |||
/// <returns>A wrapped TF_Operation*.</returns> | |||
public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs, Operation[] control_inputs) | |||
{ | |||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
@@ -102,6 +124,8 @@ namespace Tensorflow | |||
var status = new Status(); | |||
// Add control inputs | |||
foreach (var control_input in control_inputs) | |||
c_api.TF_AddControlInput(op_desc, control_input); | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
@@ -170,8 +194,11 @@ namespace Tensorflow | |||
// inner_device_stack = default_graph._device_function_stack | |||
// var outer_context = default_graph.as_default; | |||
var outer_graph = get_default_graph(); | |||
// outer_device_stack = None | |||
Python.with(ops.control_dependencies(null), delegate | |||
{ | |||
var outer_graph = get_default_graph(); | |||
// outer_device_stack = None | |||
}); | |||
} | |||
private static int uid_number = 0; | |||
@@ -46,14 +46,13 @@ namespace TensorFlowNET.UnitTest | |||
var x = tf.Variable(10, name: "x"); | |||
var model = tf.global_variables_initializer(); | |||
using (var session = tf.Session()) | |||
{ | |||
session.run(x.initializer); | |||
session.run(model); | |||
for(int i = 0; i < 5; i++) | |||
{ | |||
var x1 = x + 1; | |||
var result = session.run(x1); | |||
x = x + 1; | |||
var result = session.run(x); | |||
print(result); | |||
} | |||
} | |||