Browse Source

tf.WhileContext()

tags/v0.12
Oceania2018 6 years ago
parent
commit
879067deb4
12 changed files with 458 additions and 26 deletions
  1. +6
    -0
      src/TensorFlowHub/MnistDataSet.cs
  2. +19
    -0
      src/TensorFlowNET.Core/APIs/tf.control.cs
  3. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.math.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  5. +36
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  6. +165
    -12
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  7. +34
    -7
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  8. +91
    -3
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  9. +87
    -0
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  10. +3
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  11. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  12. +13
    -1
      test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs

+ 6
- 0
src/TensorFlowHub/MnistDataSet.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using NumSharp;
using Tensorflow;
@@ -21,7 +22,12 @@ namespace Tensorflow.Hub

images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
images.astype(dataType);
// for debug np.multiply performance
var sw = new Stopwatch();
sw.Start();
images = np.multiply(images, 1.0f / 255.0f);
sw.Stop();
Console.WriteLine($"{sw.ElapsedMilliseconds}ms");
Data = images;

labels.astype(dataType);


+ 19
- 0
src/TensorFlowNET.Core/APIs/tf.control.cs View File

@@ -14,10 +14,29 @@
limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow
{
public static partial class tf
{
/// <summary>
/// Public entry point for building a while loop. Every argument is handed
/// unchanged to <c>control_flow_ops.while_loop</c>, whose result is returned.
/// </summary>
/// <param name="cond">Callable evaluated to decide whether the loop continues.</param>
/// <param name="body">Callable that produces the next loop value.</param>
/// <param name="loop_vars">Initial loop variables.</param>
/// <param name="shape_invariants">Optional shape invariants, forwarded as-is.</param>
/// <param name="parallel_iterations">Forwarded as-is.</param>
/// <param name="back_prop">Forwarded as-is.</param>
/// <param name="swap_memory">Forwarded as-is.</param>
/// <param name="name">Optional name, forwarded as-is.</param>
/// <param name="maximum_iterations">Optional iteration limit, forwarded as-is.</param>
/// <param name="return_same_structure">Forwarded as-is.</param>
/// <returns>The tensor returned by <c>control_flow_ops.while_loop</c>.</returns>
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
    TensorShape shape_invariants = null,
    int parallel_iterations = 10,
    bool back_prop = true,
    bool swap_memory = false,
    string name = null,
    int? maximum_iterations = null,
    bool return_same_structure = false)
{
    // Positional order matches the parameter order of control_flow_ops.while_loop.
    return control_flow_ops.while_loop(
        cond,
        body,
        loop_vars,
        shape_invariants,
        parallel_iterations,
        back_prop,
        swap_memory,
        name,
        maximum_iterations,
        return_same_structure);
}

/// <summary>
/// Thin wrapper that forwards <paramref name="control_inputs"/> to
/// <c>ops.control_dependencies</c> and returns its controller.
/// </summary>
/// <param name="control_inputs">Operations passed through unchanged.</param>
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
{
    return ops.control_dependencies(control_inputs);
}
}


+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -39,8 +39,8 @@ namespace Tensorflow
public static Tensor asin(Tensor x, string name = null)
=> gen_math_ops.asin(x, name);

/// <summary>
/// Adds <paramref name="a"/> and <paramref name="b"/> by delegating to
/// <c>gen_math_ops.add</c>.
/// </summary>
/// <param name="a">First operand.</param>
/// <param name="b">Second operand.</param>
/// <param name="name">Optional operation name; callers that omit it get <c>null</c>.</param>
/// <returns>The tensor produced by <c>gen_math_ops.add</c>.</returns>
// The former two-argument overload is covered by the optional `name` parameter,
// so a single definition serves both call shapes.
public static Tensor add<Tx, Ty>(Tx a, Ty b, string name = null)
    => gen_math_ops.add(a, b, name: name);

/// <summary>
/// Computes atan of x element-wise.


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow
/// </summary>
/// <param name="input_ops">The data input ops for an op to be created.</param>
/// <returns>A list of control inputs for the op to be created.</returns>
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
{
var ret = new List<ITensorOrOperation>();



+ 36
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -53,6 +53,11 @@ namespace Tensorflow.Operations
protected Stack<ControlFlowContext> _context_stack;
protected ControlFlowContext _outer_context;

/// <summary>
/// The keys are the names of tensors referenced by but external to this
/// context. Each value is the Tensor that should be used by this context to
/// access the key value (e.g. a switch output guarding a cond input value).
/// </summary>
protected Dictionary<string, ITensorOrOperation> _external_values;

public ControlFlowContext()
@@ -68,6 +73,12 @@ namespace Tensorflow.Operations
_outer_context = ops.get_default_graph()._get_control_flow_context();
if (values_def != null)
_init_values_from_proto(values_def, import_scope: import_scope);
else
{
_values = new HashSet<string>();
_external_values = new Dictionary<string, ITensorOrOperation>();
}

}

public void __enter__()
@@ -114,6 +125,27 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(this);
}

/// <summary>
/// Converts <paramref name="data"/> to a tensor and creates an Enter op for it
/// in the frame named <paramref name="frame_name"/>.
/// </summary>
/// <param name="data">Value to make available inside the frame.</param>
/// <param name="frame_name">Name of the frame passed to the Enter op.</param>
/// <param name="is_constant">Forwarded to the Enter op.</param>
/// <param name="parallel_iterations">Forwarded to the Enter op.</param>
/// <param name="use_ref">When true and the converted tensor has a ref dtype, the call is rejected.</param>
/// <param name="use_input_shape">When true, the result is given the shape of the converted input.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output of the Enter op.</returns>
/// <exception cref="NotImplementedException">Ref dtype input with <paramref name="use_ref"/> set.</exception>
protected virtual Tensor _Enter(Tensor data, string frame_name,
    bool is_constant = false,
    int parallel_iterations = 10,
    bool use_ref = true,
    bool use_input_shape = true,
    string name = null)
{
    data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true);

    // The ref-typed Enter path has not been ported.
    if (data.dtype.is_ref_dtype() && use_ref)
        throw new NotImplementedException("_Enter");

    var entered = gen_control_flow_ops.enter(
        data, frame_name, is_constant, parallel_iterations, name: name);

    if (!use_input_shape)
        return entered;

    entered.SetShape(data.TensorShape);
    return entered;
}

/// <summary>
/// Exit this control flow context.
/// </summary>
@@ -184,6 +216,10 @@ namespace Tensorflow.Operations
return true;
}

/// <summary>
/// Placeholder: not implemented in this class; always throws.
/// </summary>
/// <param name="op">Operation to test (currently unused).</param>
/// <exception cref="NotImplementedException">Always.</exception>
protected virtual bool _IsInOuterContext(Operation op)
    => throw new NotImplementedException("_IsInOuterContext");

protected virtual void _RemoveExternalControlEdges(Operation op)
{


+ 165
- 12
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -15,8 +15,12 @@
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Operations.ControlFlows;
using Tensorflow.Util;
using static Tensorflow.Python;
using static Tensorflow.control_flow_ops;

namespace Tensorflow.Operations
{
@@ -32,10 +36,14 @@ namespace Tensorflow.Operations
bool _swap_memory;
Tensor _pivot_for_pred;
Tensor _pivot_for_body;
Tensor[] _loop_exits;
Tensor[] _loop_enters;
List<Tensor> _loop_exits;
List<Tensor> _loop_enters;
Graph _graph;
public override GradLoopState grad_state => _grad_state;
public override bool back_prop => _back_prop;

public WhileContext(int parallel_iterations = 10,
public WhileContext(int? maximum_iterations = null,
int parallel_iterations = 10,
bool back_prop = true,
bool swap_memory = false,
string name = "while_context",
@@ -49,12 +57,27 @@ namespace Tensorflow.Operations
}
else
{
__init__();
_init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name);
}

_grad_state = grad_state;
}

/// <summary>
/// Initializes this context from explicit arguments (as opposed to a proto).
/// </summary>
/// <param name="maximum_iterations">Optional iteration limit.</param>
/// <param name="parallel_iterations">Stored and later passed to the Enter ops built by the loop.</param>
/// <param name="back_prop">Stored in <c>_back_prop</c>.</param>
/// <param name="swap_memory">Stored in <c>_swap_memory</c>.</param>
/// <param name="name">Base name; made unique within the default graph.</param>
private void _init_from_args(int? maximum_iterations,
    int parallel_iterations,
    bool back_prop,
    bool swap_memory,
    string name)
{
    // Resolve the default graph once and reuse it for naming and for _graph.
    var graph = ops.get_default_graph();

    _name = graph.unique_name(name);
    // Previously dropped: _BuildLoop reads _parallel_iterations when creating
    // Enter ops, so it must carry the caller's value rather than the field default.
    _parallel_iterations = parallel_iterations;
    // NOTE(review): maximum_iterations is accepted but still not stored here;
    // the matching field's declaration is not visible in this view — confirm and wire up.
    _back_prop = back_prop;
    _swap_memory = swap_memory;
    _loop_exits = new List<Tensor>();
    _loop_enters = new List<Tensor>();
    _graph = graph;
}

private void _init_from_proto(WhileContextDef context_def, string import_scope = null)
{
var g = ops.get_default_graph();
@@ -70,26 +93,156 @@ namespace Tensorflow.Operations
// The boolean tensor for loop termination condition.
_pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor;
// The list of exit tensors for loop variables.
_loop_exits = new Tensor[context_def.LoopExitNames.Count];
_loop_exits = new List<Tensor>();
foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames))
_loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor;
_loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor);
// The list of enter tensors for loop variables.
_loop_enters = new Tensor[context_def.LoopEnterNames.Count];
_loop_enters = new List<Tensor>();
foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames))
_loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor;
_loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor);

__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
}

/// <summary>
/// Add the loop termination condition and body to the graph.
/// </summary>
/// <param name="pred">Loop condition callable.</param>
/// <param name="body">Loop body callable.</param>
/// <param name="loop_vars">Initial loop variables.</param>
/// <param name="shape_invariants">Optional shape invariants, forwarded to the inner builder.</param>
/// <param name="return_same_structure">Currently unused by this method.</param>
/// <returns>
/// The exit tensors packed like the body result, or <c>null</c> when the packed
/// value is not a <c>Tensor[]</c> (result of the <c>as</c> cast).
/// </returns>
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
    Func<Tensor, Tensor> body,
    Tensor[] loop_vars,
    TensorShape shape_invariants,
    bool return_same_structure)
{
    // Keep original_loop_vars to identify which are TensorArrays
    var original_loop_vars = loop_vars;

    Tensor[] original_body_result;
    Tensor[] exit_vars;

    Enter();
    try
    {
        (original_body_result, exit_vars) = _BuildLoop(
            pred, body, original_loop_vars, loop_vars, shape_invariants);
    }
    finally
    {
        // Always leave the context, even when graph construction fails,
        // so a failed build does not leave this context entered.
        Exit();
    }

    var flat_result = original_body_result;

    var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
    var packed_exit_vars = nest.pack_sequence_as(
        structure: original_body_result,
        flat_sequence: exit_vars_with_tensor_arrays);

    return packed_exit_vars as Tensor[];
}

/// <summary>
/// Builds the loop graph: Enter ops for the loop variables, Merge ops, the
/// LoopCond pivot from <paramref name="pred"/>, Switch/Identity ops feeding
/// <paramref name="body"/>, the NextIteration back edge and the Exit ops.
/// </summary>
/// <param name="pred">Loop condition; called with the first merged variable only.</param>
/// <param name="body">Loop body; called with the first body variable only.</param>
/// <param name="original_loop_vars">Loop variables as originally supplied.</param>
/// <param name="loop_vars">Loop variables to enter into the loop frame.</param>
/// <param name="shape_invariants">Optional shape invariants; when null the Enter ops keep the input shape.</param>
/// <returns>The body result wrapped in an array, and the exit tensors.</returns>
/// <exception cref="NotImplementedException">This context has an outer context.</exception>
private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred,
    Func<Tensor, Tensor> body,
    Tensor[] original_loop_vars,
    Tensor[] loop_vars,
    TensorShape shape_invariants)
{
    var flat_loop_vars = original_loop_vars;

    // Let the context know the loop variables so the loop variables
    // would be added in the outer contexts properly.
    _InitializeValues(loop_vars);
    var real_vars = loop_vars;
    Tensor[] enter_vars = null;
    tf_with(ops.control_dependencies(null), delegate
    {
        enter_vars = real_vars.Select(x => _Enter(x,
                _name,
                is_constant: false,
                parallel_iterations: _parallel_iterations,
                use_input_shape: shape_invariants == null))
            .ToArray();

        foreach (var x in enter_vars)
        {
            x.graph.prevent_feeding(x);
            if (_outer_context != null)
                _outer_context.AddInnerOp(x.op);
        }
    });

    // Finding the closest enclosing control pivot is not ported yet. The
    // previous empty `while (outer_context != null) { }` never advanced and
    // would spin forever for a nested context, so fail fast instead.
    if (_outer_context != null)
        throw new NotImplementedException("_BuildLoop: control pivot lookup for an enclosing context");

    _SetShapeInvariants(real_vars, enter_vars, shape_invariants);

    // Fix the control inputs and control flow context of these enter ops.
    _FixControlInputsAndContext(enter_vars);
    _InitializeValues(enter_vars);
    _loop_enters = enter_vars.ToList();

    // Both Merge inputs start as the Enter output; the second input is
    // replaced by the back edge in _AddNextAndBackEdge below.
    var merge_vars = enter_vars
        .Select(x => merge(new[] { x, x }))
        .ToArray();

    _pivot_for_pred = merge_vars[0];

    // Build the graph for pred.
    var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
    // NOTE(review): only the first loop variable reaches pred and body.
    var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0]));
    _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
    var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
        .ToArray();

    // Build the graph for body.
    var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
    // Convert TensorArray flow variables inside the context back into
    // their associated TensorArrays for calling the body.
    var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
    var body_result = body(packed_vars_for_body[0]);
    var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);

    // Store body_result to keep track of TensorArrays returned by body
    var original_body_result = new[] { body_result };
    // Convert TensorArrays returned by body into their flow variables
    var result = new[] { body_result };

    // Wire each body output back to its Merge op through NextIteration.
    var next_vars = new List<Tensor>();
    foreach (var (m, v) in zip(merge_vars, result))
        next_vars.Add(_AddNextAndBackEdge(m, v));

    // Add the exit ops.
    var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
    _loop_exits = exit_vars;

    // Exit the loop.
    // ExitResult(exit_vars);
    return (original_body_result, exit_vars.ToArray());
}

/// <summary>
/// For every Enter tensor: moves its op into this context and attaches the
/// control inputs of its data input that <c>_IsInOuterContext</c> accepts,
/// then records the op with the default graph.
/// </summary>
/// <param name="enters">Enter tensors created for the loop variables.</param>
private void _FixControlInputsAndContext(Tensor[] enters)
{
    var default_graph = ops.get_default_graph();

    for (var index = 0; index < enters.Length; index++)
    {
        var enter_op = enters[index].op;
        var source_op = enter_op.inputs[0].op;
        var dependencies = default_graph._control_dependencies_for_inputs(new[] { source_op });

        // op for op in control_inputs if self._IsInOuterContext(op)
        var outer_ops = (from dependency in dependencies
                         where _IsInOuterContext(dependency.op)
                         select dependency.op)
            .ToArray();

        enter_op._set_control_flow_context(this);
        enter_op._add_control_inputs(outer_ops);
        default_graph._record_op_seen_by_control_dependencies(enter_op);
    }
}

/// <summary>
/// Replaces <c>_values</c> with a fresh set holding the name of every tensor
/// in <paramref name="values"/>.
/// </summary>
/// <param name="values">Tensors whose names are recorded.</param>
private void _InitializeValues(Tensor[] values)
{
    var names = new HashSet<string>();
    _values = names;

    for (var index = 0; index < values.Length; index++)
        names.Add(values[index].name);
}

/// <summary>
/// Returns this instance as the while context.
/// </summary>
public override WhileContext GetWhileContext() => this;

public WhileContext from_proto(WhileContextDef proto, string import_scope)
{


+ 34
- 7
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -141,30 +141,57 @@ namespace Tensorflow.Operations
string base_name = null;
tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope);

Func<string, TensorShape, TF_DataType, Tensor> _create_ta = (name, element_shape, dtype_) =>
Func<string, TensorShape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) =>
{
new TensorArray(dtype: dtype_,
var ta = new TensorArray(dtype: dtype_,
size: time_steps,
element_shape: element_shape,
tensor_array_name: base_name + name);
throw new NotImplementedException("");
return ta;
};

bool in_graph_mode = true;
var output_ta = new List<TensorArray>();
var input_ta = new List<TensorArray>();
if (in_graph_mode)
{
foreach(var (i, out_size) in enumerate(flat_output_size))
foreach (var (i, out_size) in enumerate(flat_output_size))
{
_create_ta($"output_{i}",
output_ta.Add(_create_ta($"output_{i}",
new TensorShape(const_batch_size).concatenate(
_maybe_tensor_shape_from_tensor(out_size)),
_infer_state_dtype(dtype, state));
_infer_state_dtype(dtype, state)));
}

foreach (var (i, flat_input_i) in enumerate(flat_input))
{
input_ta.Add(_create_ta($"input_{i}",
new TensorShape(flat_input_i.dims.Skip(1).ToArray()),
flat_input_i.dtype));
}

for (int i = 0; i < input_ta.Count; i++)
{
var (ta, input_) = (input_ta[0], flat_input[0]);
}
}

// Make sure that we run at least 1 step, if necessary, to ensure
// the TensorArrays pick up the dynamic shape.
Tensor loop_bound;
if (in_graph_mode)
loop_bound = math_ops.minimum(
time_steps, math_ops.maximum(1, max_sequence_length));

/*Func<Tensor, Tensor> cond = (ctime) =>
{
return null;
};

control_flow_ops.while_loop(
cond: cond,
body = );*/

throw new NotImplementedException("");
}



+ 91
- 3
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -26,6 +26,44 @@ namespace Tensorflow
{
public class control_flow_ops
{
/// <summary>
/// Wraps <paramref name="v"/> in a NextIteration op and installs it as input 1
/// of the op that produced <paramref name="m"/>.
/// </summary>
/// <param name="m">Merge tensor whose op receives the back edge.</param>
/// <param name="v">Value to feed into the next iteration.</param>
/// <param name="enforce_shape_invariant">When true, <c>_EnforceShapeInvariant</c> is called first.</param>
/// <returns>The NextIteration tensor.</returns>
public static Tensor _AddNextAndBackEdge(Tensor m, Tensor v, bool enforce_shape_invariant = true)
{
    var next = _NextIteration(ops.convert_to_tensor(v));

    if (enforce_shape_invariant)
    {
        _EnforceShapeInvariant(m, next);
    }

    m.op._update_input(1, next);
    return next;
}

/// <summary>
/// Intended to check that the shapes of the loop variables are invariant
/// across iterations. The body is currently empty, so no check is performed.
/// </summary>
/// <param name="merge_var">Merge tensor of the loop variable (currently unused).</param>
/// <param name="next_var">Next-iteration tensor of the loop variable (currently unused).</param>
public static void _EnforceShapeInvariant(Tensor merge_var, Tensor next_var)
{
    // TODO: shape-invariant validation is not implemented yet.

}

/// <summary>
/// Converts <paramref name="data"/> to a tensor and creates the exit op for it,
/// using the ref variant when the converted tensor has a ref dtype.
/// </summary>
/// <param name="data">Value leaving the frame.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by the selected generated op.</returns>
public static Tensor exit(Tensor data, string name = null)
{
    Tensor converted = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true);

    return converted.dtype.is_ref_dtype()
        ? gen_control_flow_ops.ref_exit(converted, name: name)
        : gen_control_flow_ops._exit(converted, name: name);
}

/// <summary>
/// Converts <paramref name="data"/> to a tensor and creates the next-iteration
/// op for it, using the ref variant when the converted tensor has a ref dtype.
/// </summary>
/// <param name="data">Value to carry into the next iteration.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by the selected generated op.</returns>
public static Tensor _NextIteration(Tensor data, string name = null)
{
    Tensor converted = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true);

    if (!converted.dtype.is_ref_dtype())
        return gen_control_flow_ops.next_iteration(converted, name: name);

    return gen_control_flow_ops.ref_next_iteration(converted, name: name);
}

public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null)
{
return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope =>
@@ -213,6 +251,14 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}

/// <summary>
/// Does nothing when <paramref name="shapes"/> is null; applying explicit
/// shape invariants is not implemented and throws.
/// </summary>
/// <param name="input_vars">Loop input variables (currently unused).</param>
/// <param name="enter_vars">Enter tensors of the loop variables (currently unused).</param>
/// <param name="shapes">Optional shape invariants.</param>
/// <exception cref="NotImplementedException"><paramref name="shapes"/> is not null.</exception>
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null)
{
    // Written as a negated equality so the same comparison as before is evaluated.
    if (!(shapes == null))
    {
        throw new NotImplementedException("_SetShapeInvariants");
    }
}

/// <summary>
/// Forwards `data` to an output determined by `pred`.
/// If `pred` is false, the `data` input is forwarded to the first output.
@@ -516,10 +562,52 @@ namespace Tensorflow
throw new NotImplementedException("ZerosLikeOutsideLoop");
}

/// <summary>
/// Repeat `body` while the condition `cond` is true.
/// </summary>
/// <param name="cond">Callable producing the loop condition.</param>
/// <param name="body">Callable producing the next loop value.</param>
/// <param name="loop_vars">Initial loop variables; must contain at least one tensor.</param>
/// <param name="shape_invariants">Optional shape invariants, forwarded to the loop builder.</param>
/// <param name="parallel_iterations">Must be at least 1.</param>
/// <param name="back_prop">Forwarded to the <c>WhileContext</c>.</param>
/// <param name="swap_memory">Forwarded to the <c>WhileContext</c>.</param>
/// <param name="name">Optional name scope; defaults to "while".</param>
/// <param name="maximum_iterations">Optional iteration limit, forwarded to the <c>WhileContext</c>.</param>
/// <param name="return_same_structure">Forwarded to the loop builder.</param>
/// <returns>The loop result tensor selected from the values returned by <c>BuildLoop</c>.</returns>
/// <exception cref="ValueError">
/// No loop variables, a null <paramref name="cond"/> or <paramref name="body"/>,
/// or <paramref name="parallel_iterations"/> below 1.
/// </exception>
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
    TensorShape shape_invariants = null,
    int parallel_iterations = 10,
    bool back_prop = true,
    bool swap_memory = false,
    string name = null,
    int? maximum_iterations = null,
    bool return_same_structure = false)
{
    // Return the value computed inside the name scope. Previously it was
    // discarded and the method unconditionally threw NotImplementedException.
    return tf_with(ops.name_scope(name, "while", loop_vars), scope =>
    {
        if (loop_vars == null || loop_vars.Length == 0)
            throw new ValueError("No loop variables provided");
        if (cond == null)
            throw new ValueError("cond must be callable.");
        if (body == null)
            throw new ValueError("body must be callable.");
        if (parallel_iterations < 1)
            throw new ValueError("parallel_iterations must be a positive integer.");

        var loop_context = new WhileContext(
            maximum_iterations: maximum_iterations,
            parallel_iterations: parallel_iterations,
            back_prop: back_prop,
            swap_memory: swap_memory);

        // Only top-level loop contexts are registered in the graph collection.
        if (loop_context.outer_context == null)
            ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context);

        var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
            return_same_structure);

        // NOTE(review): index 1 presumes an iteration counter was prepended to the
        // loop variables; no such counter is added in this method — confirm.
        if (maximum_iterations != null)
            return results[1];
        else
            return results[0];
    });
}

}


+ 87
- 0
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs View File

@@ -20,6 +20,93 @@ namespace Tensorflow
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

/// <summary>
/// Creates or finds a child frame, and makes `data` available to the child frame.
/// </summary>
/// <param name="data">Tensor to make available to the child frame.</param>
/// <param name="frame_name">Name of the child frame.</param>
/// <param name="is_constant">Passed to the "Enter" op as <c>is_constant</c>.</param>
/// <param name="parallel_iterations">Passed to the "Enter" op as <c>parallel_iterations</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output of the created "Enter" op.</returns>
public static Tensor enter(Tensor data, string frame_name = "frame_name", bool is_constant = false, int parallel_iterations = 10, string name = null)
{
    // Member names of the anonymous object are kept exactly as before.
    var enter_args = new
    {
        data,
        frame_name,
        is_constant,
        parallel_iterations
    };

    return _op_def_lib._apply_op_helper("Enter", name, enter_args).output;
}

/// <summary>
/// Forwards the input to the output through a "LoopCond" op.
/// </summary>
/// <param name="input">Tensor passed to the op as <c>input</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output of the created "LoopCond" op.</returns>
public static Tensor loop_cond(Tensor input, string name = null)
    => _op_def_lib._apply_op_helper("LoopCond", name, new { input }).output;

/// <summary>
/// Makes its input available to the next iteration ("RefNextIteration" op).
/// </summary>
/// <param name="data">Tensor passed to the op as <c>data</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output tensor of the created op.</returns>
public static Tensor ref_next_iteration(Tensor data, string name = null)
{
    var _op = _op_def_lib._apply_op_helper("RefNextIteration", name, new { data });

    // Return the op's output explicitly, as `enter` and `loop_cond` do,
    // instead of returning the operation where a Tensor is declared.
    return _op.output;
}

/// <summary>
/// Makes its input available to the next iteration ("NextIteration" op).
/// </summary>
/// <param name="data">Tensor passed to the op as <c>data</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output tensor of the created op.</returns>
public static Tensor next_iteration(Tensor data, string name = null)
{
    var _op = _op_def_lib._apply_op_helper("NextIteration", name, new { data });

    // Return the op's output explicitly, as `enter` and `loop_cond` do,
    // instead of returning the operation where a Tensor is declared.
    return _op.output;
}

/// <summary>
/// Exits the current frame to its parent frame ("RefExit" op).
/// </summary>
/// <param name="data">Tensor passed to the op as <c>data</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output tensor of the created op.</returns>
public static Tensor ref_exit(Tensor data, string name = null)
{
    var _op = _op_def_lib._apply_op_helper("RefExit", name, new { data });

    // Return the op's output explicitly, as `enter` and `loop_cond` do,
    // instead of returning the operation where a Tensor is declared.
    return _op.output;
}

/// <summary>
/// Exits the current frame to its parent frame ("Exit" op).
/// </summary>
/// <param name="data">Tensor passed to the op as <c>data</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output tensor of the created op.</returns>
public static Tensor _exit(Tensor data, string name = null)
{
    var _op = _op_def_lib._apply_op_helper("Exit", name, new { data });

    // Return the op's output explicitly, as `enter` and `loop_cond` do,
    // instead of returning the operation where a Tensor is declared.
    return _op.output;
}

public static Operation no_op(string name = null)
{
var _op = _op_def_lib._apply_op_helper("NoOp", name, null);


+ 3
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -516,6 +516,9 @@ namespace Tensorflow
});
}

/// <summary>
/// Delegates to <c>gen_math_ops.minimum</c> with the given operands and name.
/// </summary>
/// <param name="x">First operand.</param>
/// <param name="y">Second operand.</param>
/// <param name="name">Optional op name.</param>
public static Tensor minimum<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.minimum(x, y, name: name);
}

/// <summary>
/// Delegates to <c>gen_math_ops.maximum</c> with the given operands and name.
/// </summary>
/// <param name="x">First operand.</param>
/// <param name="y">Second operand.</param>
/// <param name="name">Optional op name.</param>
public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.maximum(x, y, name: name);
}



+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -416,5 +416,6 @@ namespace Tensorflow
}
}

/// <summary>
/// Integer value stored alongside the tensor.
/// NOTE(review): no reader or writer of this property is visible here — confirm its intended use.
/// </summary>
public int tensor_int_val { get; set; }
}
}

+ 13
- 1
test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs View File

@@ -8,6 +8,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestClass]
public class WhileContextTestCase : PythonTest
{
/// <summary>
/// Builds a loop over a single integer constant with a "less than 10"
/// condition and an "add 1" body; the test only checks that construction
/// completes without throwing.
/// https://www.tensorflow.org/api_docs/python/tf/while_loop
/// </summary>
[TestMethod]
public void SimpleWhileLoop()
{
    var start = constant_op.constant(0, name: "i");

    Func<Tensor, Tensor> condition = x => tf.less(x, 10, name: "c");
    Func<Tensor, Tensor> step = x => tf.add(x, 1, name: "c");

    var loop_result = control_flow_ops.while_loop(condition, step, new[] { start });
}
private void _testWhileContextHelper(int? maximum_iterations = null)
{
// TODO: implement missing code dependencies
@@ -17,7 +29,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
control_flow_ops.while_loop(
c, b, new[] { i }, maximum_iterations);
c, b, new[] { i }, maximum_iterations: maximum_iterations);
foreach (Operation op in sess.graph.get_operations())
{
var control_flow_context = op._get_control_flow_context();


Loading…
Cancel
Save