@@ -0,0 +1,105 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Graph | |||||
{ | |||||
public Context _control_flow_context; | |||||
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||||
public Queue<_ControlDependenciesController> _control_dependencies_stack | |||||
{ | |||||
get | |||||
{ | |||||
return _graph_control_dependencies_stack; | |||||
} | |||||
set | |||||
{ | |||||
_graph_control_dependencies_stack = value; | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// For an op that takes `input_ops` as inputs, compute control inputs. | |||||
/// </summary> | |||||
/// <param name="input_ops">The data input ops for an op to be created.</param> | |||||
/// <returns>A list of control inputs for the op to be created.</returns> | |||||
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||||
{ | |||||
Operation[] ret = new Operation[0]; | |||||
foreach(var controller in _control_dependencies_stack) | |||||
{ | |||||
bool dominated = false; | |||||
// If any of the input_ops already depends on the inputs from controller, | |||||
// we say that the new op is dominated (by that input), and we therefore | |||||
// do not need to add control dependencies for this controller's inputs. | |||||
foreach(var op in input_ops) | |||||
{ | |||||
if (controller.op_in_group(op)) | |||||
{ | |||||
dominated = true; | |||||
break; | |||||
} | |||||
} | |||||
if (!dominated) | |||||
ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray(); | |||||
} | |||||
return ret; | |||||
} | |||||
public _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||||
{ | |||||
if (control_inputs == null) | |||||
return new _ControlDependenciesController(this, null); | |||||
var control_ops = new List<Operation>(); | |||||
foreach (var c in control_inputs) | |||||
{ | |||||
control_ops.Add(c); | |||||
} | |||||
return new _ControlDependenciesController(this, control_ops); | |||||
} | |||||
/// <summary> | |||||
/// Returns the current control flow context. | |||||
/// </summary> | |||||
/// <returns>A context object.</returns> | |||||
public Context _get_control_flow_context() | |||||
{ | |||||
return _control_flow_context; | |||||
} | |||||
/// <summary> | |||||
/// Sets the current control flow context. | |||||
/// </summary> | |||||
/// <param name="ctx">a context object.</param> | |||||
public void _set_control_flow_context(Context ctx) | |||||
{ | |||||
_control_flow_context = ctx; | |||||
} | |||||
public void _push_control_dependencies_controller(_ControlDependenciesController controller) | |||||
{ | |||||
_control_dependencies_stack.Enqueue(controller); | |||||
} | |||||
public void _pop_control_dependencies_controller(_ControlDependenciesController controller) | |||||
{ | |||||
_control_dependencies_stack.Dequeue(); | |||||
} | |||||
public void _record_op_seen_by_control_dependencies(Operation op) | |||||
{ | |||||
foreach (var controller in _control_dependencies_stack) | |||||
controller.add_op(op); | |||||
} | |||||
} | |||||
} |
@@ -142,19 +142,9 @@ namespace Tensorflow | |||||
return op; | return op; | ||||
} | } | ||||
/// <summary> | |||||
/// For an op that takes `input_ops` as inputs, compute control inputs. | |||||
/// </summary> | |||||
/// <param name="input_ops">The data input ops for an op to be created.</param> | |||||
/// <returns>A list of control inputs for the op to be created.</returns> | |||||
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||||
{ | |||||
return new Operation[0]; | |||||
} | |||||
private void _create_op_helper(Operation op, bool compute_device = true) | private void _create_op_helper(Operation op, bool compute_device = true) | ||||
{ | { | ||||
_record_op_seen_by_control_dependencies(op); | |||||
} | } | ||||
public void _add_op(Operation op) | public void _add_op(Operation op) | ||||
@@ -0,0 +1,80 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// Context manager for `control_dependencies()` | |||||
/// </summary> | |||||
public class _ControlDependenciesController : IPython | |||||
{ | |||||
private Graph _graph; | |||||
private List<Operation> _control_inputs_val; | |||||
private List<Operation> _seen_nodes; | |||||
private Queue<_ControlDependenciesController> _old_stack; | |||||
private bool _new_stack; | |||||
private Context _old_control_flow_context; | |||||
public Operation[] control_inputs => _control_inputs_val.ToArray(); | |||||
public _ControlDependenciesController(Graph graph, List<Operation> control_inputs) | |||||
{ | |||||
_graph = graph; | |||||
if (control_inputs == null) | |||||
{ | |||||
_control_inputs_val = new List<Operation>(); | |||||
_new_stack = true; | |||||
} | |||||
else | |||||
{ | |||||
_control_inputs_val = control_inputs; | |||||
_new_stack = false; | |||||
} | |||||
_seen_nodes = new List<Operation>(); | |||||
} | |||||
public void add_op(Operation op) | |||||
{ | |||||
_seen_nodes.Add(op); | |||||
} | |||||
public bool op_in_group(Operation op) | |||||
{ | |||||
return _seen_nodes.Contains(op); | |||||
} | |||||
public void __enter__() | |||||
{ | |||||
if (_new_stack) | |||||
{ | |||||
// Clear the control_dependencies graph. | |||||
_old_stack = _graph._control_dependencies_stack; | |||||
_graph._control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||||
// Clear the control_flow_context too. | |||||
_old_control_flow_context = _graph._get_control_flow_context(); | |||||
_graph._set_control_flow_context(null); | |||||
} | |||||
_graph._push_control_dependencies_controller(this); | |||||
} | |||||
public void __exit__() | |||||
{ | |||||
_graph._pop_control_dependencies_controller(this); | |||||
if (_new_stack) | |||||
{ | |||||
_graph._control_dependencies_stack = _old_stack; | |||||
_graph._set_control_flow_context(_old_control_flow_context); | |||||
} | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -39,6 +39,14 @@ namespace Tensorflow | |||||
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | ||||
public Operation[] control_inputs | |||||
{ | |||||
get | |||||
{ | |||||
return GetControlInputs(); | |||||
} | |||||
} | |||||
public unsafe Operation[] GetControlInputs() | public unsafe Operation[] GetControlInputs() | ||||
{ | { | ||||
var control_inputs = new Operation[NumControlInputs]; | var control_inputs = new Operation[NumControlInputs]; | ||||
@@ -49,15 +49,53 @@ namespace Tensorflow | |||||
c_api.TF_FinishOperation(desc, status); | c_api.TF_FinishOperation(desc, status); | ||||
} | } | ||||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
/// <summary> | |||||
/// Creates an `Operation`. | |||||
/// </summary> | |||||
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||||
/// <param name="g">`Graph`. The parent graph.</param> | |||||
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||||
/// <param name="output_types">list of `DType` objects.</param> | |||||
/// <param name="control_inputs"> | |||||
/// list of operations or tensors from which to have a | |||||
/// control dependency. | |||||
/// </param> | |||||
/// <param name="input_types"> | |||||
/// List of `DType` objects representing the | |||||
/// types of the tensors accepted by the `Operation`. By default | |||||
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||||
/// reference-typed inputs must specify these explicitly. | |||||
/// </param> | |||||
/// <param name="original_op"></param> | |||||
/// <param name="op_def"></param> | |||||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
{ | { | ||||
Graph = g; | Graph = g; | ||||
// Build the list of control inputs. | |||||
var control_input_ops = new List<Operation>(); | |||||
if(control_inputs != null) | |||||
{ | |||||
foreach(var c in control_inputs) | |||||
{ | |||||
switch (c) | |||||
{ | |||||
case Operation c1: | |||||
control_input_ops.Add(c1); | |||||
break; | |||||
default: | |||||
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||||
} | |||||
} | |||||
} | |||||
// This will be set by self.inputs. | |||||
_id_value = Graph._next_id(); | _id_value = Graph._next_id(); | ||||
if(op_def == null) | if(op_def == null) | ||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
_handle = ops._create_c_op(g, node_def, inputs); | |||||
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | |||||
output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
@@ -34,6 +34,14 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_AddInput(IntPtr desc, TF_Output input); | public static extern void TF_AddInput(IntPtr desc, TF_Output input); | ||||
/// <summary> | |||||
/// Call once per control input to `desc`. | |||||
/// </summary> | |||||
/// <param name="desc">TF_OperationDescription*</param> | |||||
/// <param name="input">TF_Operation*</param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); | |||||
/// <summary> | /// <summary> | ||||
/// For inputs that take a list of tensors. | /// For inputs that take a list of tensors. | ||||
/// inputs must point to TF_Output[num_inputs]. | /// inputs must point to TF_Output[num_inputs]. | ||||
@@ -13,9 +13,8 @@ namespace Tensorflow | |||||
{ | { | ||||
name = namescope; | name = namescope; | ||||
var ops_on_device = new Dictionary<string, Operation[]>(); | |||||
// Sorts *inputs according to their devices. | // Sorts *inputs according to their devices. | ||||
var ops_on_device = new Dictionary<string, Operation[]>(); | |||||
foreach (var inp in inputs) | foreach (var inp in inputs) | ||||
{ | { | ||||
ops_on_device[inp.Device] = new Operation[] { inp }; | ops_on_device[inp.Device] = new Operation[] { inp }; | ||||
@@ -24,7 +23,9 @@ namespace Tensorflow | |||||
// 1-level tree. The root node is the returned NoOp node. | // 1-level tree. The root node is the returned NoOp node. | ||||
if (ops_on_device.Count == 1) | if (ops_on_device.Count == 1) | ||||
{ | { | ||||
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name); | |||||
var dev = ops_on_device.Keys.First(); | |||||
var deps = ops_on_device.Values.First(); | |||||
return _GroupControlDeps(dev, deps, name); | |||||
} | } | ||||
// 2-level tree. The root node is the returned NoOp node. | // 2-level tree. The root node is the returned NoOp node. | ||||
@@ -35,12 +36,21 @@ namespace Tensorflow | |||||
private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") | ||||
{ | { | ||||
if (string.IsNullOrEmpty(dev)) | |||||
Operation result = null; | |||||
Python.with(ops.control_dependencies(deps), delegate | |||||
{ | { | ||||
return gen_control_flow_ops.no_op(name); | |||||
} | |||||
if (string.IsNullOrEmpty(dev)) | |||||
{ | |||||
result = gen_control_flow_ops.no_op(name); | |||||
} | |||||
else | |||||
{ | |||||
result = gen_control_flow_ops.no_op(name); | |||||
} | |||||
}); | |||||
return null; | |||||
return result; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -13,5 +13,30 @@ namespace Tensorflow | |||||
{ | { | ||||
Console.WriteLine(obj.ToString()); | Console.WriteLine(obj.ToString()); | ||||
} | } | ||||
public static void with(IPython py, Action action) | |||||
{ | |||||
try | |||||
{ | |||||
py.__enter__(); | |||||
action(); | |||||
} | |||||
catch (Exception ex) | |||||
{ | |||||
throw ex; | |||||
} | |||||
finally | |||||
{ | |||||
py.__exit__(); | |||||
py.Dispose(); | |||||
} | |||||
} | |||||
} | |||||
public interface IPython : IDisposable | |||||
{ | |||||
void __enter__(); | |||||
void __exit__(); | |||||
} | } | ||||
} | } |
@@ -20,5 +20,10 @@ namespace Tensorflow | |||||
{ | { | ||||
return var._AsTensor(); | return var._AsTensor(); | ||||
} | } | ||||
public static implicit operator RefVariable(Tensor var) | |||||
{ | |||||
return null; | |||||
} | |||||
} | } | ||||
} | } |
@@ -166,5 +166,10 @@ namespace Tensorflow | |||||
// Recursively build initializer expressions for inputs. | // Recursively build initializer expressions for inputs. | ||||
return op; | return op; | ||||
} | } | ||||
public override string ToString() | |||||
{ | |||||
return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; | |||||
} | |||||
} | } | ||||
} | } |
@@ -78,7 +78,29 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | |||||
/// <summary> | |||||
/// Wrapper for `Graph.control_dependencies()` using the default graph. | |||||
/// </summary> | |||||
/// <param name="control_inputs"></param> | |||||
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||||
{ | |||||
return get_default_graph().control_dependencies(control_inputs); | |||||
} | |||||
/// <summary> | |||||
/// Creates a TF_Operation. | |||||
/// </summary> | |||||
/// <param name="graph">a `Graph`.</param> | |||||
/// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param> | |||||
/// <param name="inputs"> | |||||
/// A list of `Tensor`s (corresponding to scalar inputs) and lists of | |||||
/// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", | |||||
/// "list(int64)"). The length of the list should be equal to the number of | |||||
/// inputs specified by this operation's op def. | |||||
/// </param> | |||||
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | |||||
/// <returns>A wrapped TF_Operation*.</returns> | |||||
public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs, Operation[] control_inputs) | |||||
{ | { | ||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | ||||
@@ -102,6 +124,8 @@ namespace Tensorflow | |||||
var status = new Status(); | var status = new Status(); | ||||
// Add control inputs | // Add control inputs | ||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add attrs | // Add attrs | ||||
foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
@@ -170,8 +194,11 @@ namespace Tensorflow | |||||
// inner_device_stack = default_graph._device_function_stack | // inner_device_stack = default_graph._device_function_stack | ||||
// var outer_context = default_graph.as_default; | // var outer_context = default_graph.as_default; | ||||
var outer_graph = get_default_graph(); | |||||
// outer_device_stack = None | |||||
Python.with(ops.control_dependencies(null), delegate | |||||
{ | |||||
var outer_graph = get_default_graph(); | |||||
// outer_device_stack = None | |||||
}); | |||||
} | } | ||||
private static int uid_number = 0; | private static int uid_number = 0; | ||||
@@ -46,14 +46,13 @@ namespace TensorFlowNET.UnitTest | |||||
var x = tf.Variable(10, name: "x"); | var x = tf.Variable(10, name: "x"); | ||||
var model = tf.global_variables_initializer(); | var model = tf.global_variables_initializer(); | ||||
using (var session = tf.Session()) | using (var session = tf.Session()) | ||||
{ | { | ||||
session.run(x.initializer); | |||||
session.run(model); | |||||
for(int i = 0; i < 5; i++) | for(int i = 0; i < 5; i++) | ||||
{ | { | ||||
var x1 = x + 1; | |||||
var result = session.run(x1); | |||||
x = x + 1; | |||||
var result = session.run(x); | |||||
print(result); | print(result); | ||||
} | } | ||||
} | } | ||||