Browse Source

Add control_inputs in Operation #141

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
8f42762f1c
12 changed files with 327 additions and 27 deletions
  1. +105
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +1
    -11
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +80
    -0
      src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
  4. +8
    -0
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  5. +40
    -2
      src/TensorFlowNET.Core/Operations/Operation.cs
  6. +8
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  7. +17
    -7
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  8. +25
    -0
      src/TensorFlowNET.Core/Python.cs
  9. +5
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
  10. +5
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  11. +30
    -3
      src/TensorFlowNET.Core/ops.py.cs
  12. +3
    -4
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 105
- 0
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -0,0 +1,105 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow
{
public partial class Graph
{
public Context _control_flow_context;

private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
public Queue<_ControlDependenciesController> _control_dependencies_stack
{
get
{
return _graph_control_dependencies_stack;
}
set
{
_graph_control_dependencies_stack = value;
}
}

/// <summary>
/// For an op that takes `input_ops` as inputs, compute control inputs.
/// </summary>
/// <param name="input_ops">The data input ops for an op to be created.</param>
/// <returns>A list of control inputs for the op to be created.</returns>
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
{
Operation[] ret = new Operation[0];

foreach(var controller in _control_dependencies_stack)
{
bool dominated = false;
// If any of the input_ops already depends on the inputs from controller,
// we say that the new op is dominated (by that input), and we therefore
// do not need to add control dependencies for this controller's inputs.
foreach(var op in input_ops)
{
if (controller.op_in_group(op))
{
dominated = true;
break;
}
}

if (!dominated)
ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray();
}

return ret;
}

public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
{
if (control_inputs == null)
return new _ControlDependenciesController(this, null);

var control_ops = new List<Operation>();
foreach (var c in control_inputs)
{
control_ops.Add(c);
}

return new _ControlDependenciesController(this, control_ops);
}

/// <summary>
/// Returns the current control flow context.
/// </summary>
/// <returns>A context object.</returns>
public Context _get_control_flow_context()
{
return _control_flow_context;
}

/// <summary>
/// Sets the current control flow context.
/// </summary>
/// <param name="ctx">a context object.</param>
public void _set_control_flow_context(Context ctx)
{
_control_flow_context = ctx;
}

public void _push_control_dependencies_controller(_ControlDependenciesController controller)
{
_control_dependencies_stack.Enqueue(controller);
}

public void _pop_control_dependencies_controller(_ControlDependenciesController controller)
{
_control_dependencies_stack.Dequeue();
}

public void _record_op_seen_by_control_dependencies(Operation op)
{
foreach (var controller in _control_dependencies_stack)
controller.add_op(op);
}
}
}

+ 1
- 11
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -142,19 +142,9 @@ namespace Tensorflow
return op; return op;
} }


/// <summary>
/// For an op that takes `input_ops` as inputs, compute control inputs.
/// </summary>
/// <param name="input_ops">The data input ops for an op to be created.</param>
/// <returns>A list of control inputs for the op to be created.</returns>
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
{
return new Operation[0];
}

private void _create_op_helper(Operation op, bool compute_device = true) private void _create_op_helper(Operation op, bool compute_device = true)
{ {
_record_op_seen_by_control_dependencies(op);
} }


public void _add_op(Operation op) public void _add_op(Operation op)


+ 80
- 0
src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs View File

@@ -0,0 +1,80 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow
{
/// <summary>
/// Context manager for `control_dependencies()`
/// </summary>
public class _ControlDependenciesController : IPython
{
private Graph _graph;
private List<Operation> _control_inputs_val;
private List<Operation> _seen_nodes;
private Queue<_ControlDependenciesController> _old_stack;
private bool _new_stack;
private Context _old_control_flow_context;

public Operation[] control_inputs => _control_inputs_val.ToArray();

public _ControlDependenciesController(Graph graph, List<Operation> control_inputs)
{
_graph = graph;
if (control_inputs == null)
{
_control_inputs_val = new List<Operation>();
_new_stack = true;
}
else
{
_control_inputs_val = control_inputs;
_new_stack = false;
}

_seen_nodes = new List<Operation>();
}

public void add_op(Operation op)
{
_seen_nodes.Add(op);
}

public bool op_in_group(Operation op)
{
return _seen_nodes.Contains(op);
}

public void __enter__()
{
if (_new_stack)
{
// Clear the control_dependencies graph.
_old_stack = _graph._control_dependencies_stack;
_graph._control_dependencies_stack = new Queue<_ControlDependenciesController>();

// Clear the control_flow_context too.
_old_control_flow_context = _graph._get_control_flow_context();
_graph._set_control_flow_context(null);
}

_graph._push_control_dependencies_controller(this);
}

public void __exit__()
{
_graph._pop_control_dependencies_controller(this);
if (_new_stack)
{
_graph._control_dependencies_stack = _old_stack;
_graph._set_control_flow_context(_old_control_flow_context);
}
}

public void Dispose()
{
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -39,6 +39,14 @@ namespace Tensorflow


public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);


public Operation[] control_inputs
{
get
{
return GetControlInputs();
}
}

public unsafe Operation[] GetControlInputs() public unsafe Operation[] GetControlInputs()
{ {
var control_inputs = new Operation[NumControlInputs]; var control_inputs = new Operation[NumControlInputs];


+ 40
- 2
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -49,15 +49,53 @@ namespace Tensorflow
c_api.TF_FinishOperation(desc, status); c_api.TF_FinishOperation(desc, status);
} }


public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
/// <param name="g">`Graph`. The parent graph.</param>
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
/// <param name="output_types">list of `DType` objects.</param>
/// <param name="control_inputs">
/// list of operations or tensors from which to have a
/// control dependency.
/// </param>
/// <param name="input_types">
/// List of `DType` objects representing the
/// types of the tensors accepted by the `Operation`. By default
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
/// reference-typed inputs must specify these explicitly.
/// </param>
/// <param name="original_op"></param>
/// <param name="op_def"></param>
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{ {
Graph = g; Graph = g;


// Build the list of control inputs.
var control_input_ops = new List<Operation>();
if(control_inputs != null)
{
foreach(var c in control_inputs)
{
switch (c)
{
case Operation c1:
control_input_ops.Add(c1);
break;
default:
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
}
}
}

// This will be set by self.inputs.

_id_value = Graph._next_id(); _id_value = Graph._next_id();
if(op_def == null) if(op_def == null)
op_def = g.GetOpDef(node_def.Op); op_def = g.GetOpDef(node_def.Op);


_handle = ops._create_c_op(g, node_def, inputs);
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());


output_types = new TF_DataType[NumOutputs]; output_types = new TF_DataType[NumOutputs];




+ 8
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -34,6 +34,14 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_AddInput(IntPtr desc, TF_Output input); public static extern void TF_AddInput(IntPtr desc, TF_Output input);


/// <summary>
/// Call once per control input to `desc`.
/// </summary>
/// <param name="desc">TF_OperationDescription*</param>
/// <param name="input">TF_Operation*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_AddControlInput(IntPtr desc, IntPtr input);

/// <summary> /// <summary>
/// For inputs that take a list of tensors. /// For inputs that take a list of tensors.
/// inputs must point to TF_Output[num_inputs]. /// inputs must point to TF_Output[num_inputs].


+ 17
- 7
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -13,9 +13,8 @@ namespace Tensorflow
{ {
name = namescope; name = namescope;


var ops_on_device = new Dictionary<string, Operation[]>();

// Sorts *inputs according to their devices. // Sorts *inputs according to their devices.
var ops_on_device = new Dictionary<string, Operation[]>();
foreach (var inp in inputs) foreach (var inp in inputs)
{ {
ops_on_device[inp.Device] = new Operation[] { inp }; ops_on_device[inp.Device] = new Operation[] { inp };
@@ -24,7 +23,9 @@ namespace Tensorflow
// 1-level tree. The root node is the returned NoOp node. // 1-level tree. The root node is the returned NoOp node.
if (ops_on_device.Count == 1) if (ops_on_device.Count == 1)
{ {
return _GroupControlDeps(ops_on_device.Keys.First(), ops_on_device.Values.First(), name);
var dev = ops_on_device.Keys.First();
var deps = ops_on_device.Values.First();
return _GroupControlDeps(dev, deps, name);
} }


// 2-level tree. The root node is the returned NoOp node. // 2-level tree. The root node is the returned NoOp node.
@@ -35,12 +36,21 @@ namespace Tensorflow


private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "") private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
{ {
if (string.IsNullOrEmpty(dev))
Operation result = null;

Python.with(ops.control_dependencies(deps), delegate
{ {
return gen_control_flow_ops.no_op(name);
}
if (string.IsNullOrEmpty(dev))
{
result = gen_control_flow_ops.no_op(name);
}
else
{
result = gen_control_flow_ops.no_op(name);
}
});


return null;
return result;
} }
} }
} }

+ 25
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -13,5 +13,30 @@ namespace Tensorflow
{ {
Console.WriteLine(obj.ToString()); Console.WriteLine(obj.ToString());
} }

public static void with(IPython py, Action action)
{
try
{
py.__enter__();
action();
}
catch (Exception ex)
{
throw ex;
}
finally
{
py.__exit__();
py.Dispose();
}
}
}

public interface IPython : IDisposable
{
void __enter__();

void __exit__();
} }
} }

+ 5
- 0
src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs View File

@@ -20,5 +20,10 @@ namespace Tensorflow
{ {
return var._AsTensor(); return var._AsTensor();
} }

public static implicit operator RefVariable(Tensor var)
{
return null;
}
} }
} }

+ 5
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -166,5 +166,10 @@ namespace Tensorflow
// Recursively build initializer expressions for inputs. // Recursively build initializer expressions for inputs.
return op; return op;
} }

public override string ToString()
{
return $"tf.Variable '{name}' shape={shape} dtype={dtype}";
}
} }
} }

+ 30
- 3
src/TensorFlowNET.Core/ops.py.cs View File

@@ -78,7 +78,29 @@ namespace Tensorflow
} }
} }


public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
/// <summary>
/// Wrapper for `Graph.control_dependencies()` using the default graph.
/// </summary>
/// <param name="control_inputs"></param>
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
{
return get_default_graph().control_dependencies(control_inputs);
}

/// <summary>
/// Creates a TF_Operation.
/// </summary>
/// <param name="graph">a `Graph`.</param>
/// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param>
/// <param name="inputs">
/// A list of `Tensor`s (corresponding to scalar inputs) and lists of
/// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
/// "list(int64)"). The length of the list should be equal to the number of
/// inputs specified by this operation's op def.
/// </param>
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
/// <returns>A wrapped TF_Operation*.</returns>
public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs, Operation[] control_inputs)
{ {
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); var op_desc = graph.NewOperation(node_def.Op, node_def.Name);


@@ -102,6 +124,8 @@ namespace Tensorflow
var status = new Status(); var status = new Status();


// Add control inputs // Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);


// Add attrs // Add attrs
foreach (var attr in node_def.Attr) foreach (var attr in node_def.Attr)
@@ -170,8 +194,11 @@ namespace Tensorflow
// inner_device_stack = default_graph._device_function_stack // inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default; // var outer_context = default_graph.as_default;


var outer_graph = get_default_graph();
// outer_device_stack = None
Python.with(ops.control_dependencies(null), delegate
{
var outer_graph = get_default_graph();
// outer_device_stack = None
});
} }


private static int uid_number = 0; private static int uid_number = 0;


+ 3
- 4
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -46,14 +46,13 @@ namespace TensorFlowNET.UnitTest
var x = tf.Variable(10, name: "x"); var x = tf.Variable(10, name: "x");


var model = tf.global_variables_initializer(); var model = tf.global_variables_initializer();

using (var session = tf.Session()) using (var session = tf.Session())
{ {
session.run(x.initializer);
session.run(model);
for(int i = 0; i < 5; i++) for(int i = 0; i < 5; i++)
{ {
var x1 = x + 1;
var result = session.run(x1);
x = x + 1;
var result = session.run(x);
print(result); print(result);
} }
} }


Loading…
Cancel
Save