Browse Source

change WhileContext maximum_iterations to Tensor.

tags/v0.12
Oceania2018 6 years ago
parent
commit
ed9a8c88a5
1 changed files with 93 additions and 74 deletions
  1. +93
    -74
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

+ 93
- 74
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Operations
public override GradLoopState grad_state => _grad_state;
public override bool back_prop => _back_prop;

public WhileContext(int? maximum_iterations = null,
public WhileContext(Tensor maximum_iterations = null,
int parallel_iterations = 10,
bool back_prop = true,
bool swap_memory = false,
@@ -64,7 +64,7 @@ namespace Tensorflow.Operations
_grad_state = grad_state;
}

private void _init_from_args(int? maximum_iterations,
private void _init_from_args(Tensor maximum_iterations,
int parallel_iterations,
bool back_prop,
bool swap_memory,
@@ -107,9 +107,9 @@ namespace Tensorflow.Operations
/// <summary>
/// Add the loop termination condition and body to the graph.
/// </summary>
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
Func<Tensor, Tensor> body,
Tensor[] loop_vars,
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
Func<Tensor, TItem, LoopVar<TItem>> body,
TItem loop_vars,
TensorShape shape_invariants,
bool return_same_structure)
{
@@ -131,88 +131,107 @@ namespace Tensorflow.Operations
return packed_exit_vars as Tensor[];
}

private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred,
Func<Tensor, Tensor> body,
Tensor[] original_loop_vars,
Tensor[] loop_vars,
private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
{
if (tensor_or_tensor_array is TensorArray tensor_array)
return tensor_array.flow;
else if (tensor_or_tensor_array is Tensor tensor)
return tensor;

throw new NotImplementedException("_convert_tensorarray_to_flow");
}

private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
Func<Tensor, TItem, LoopVar<TItem>> body,
TItem original_loop_vars,
TItem loop_vars,
TensorShape shape_invariants)
{
var flat_loop_vars = original_loop_vars;

// Convert TensorArrays to their flow variables
var loop_vars_tensor = nest.map_structure(
_convert_tensorarray_to_flow,
nest.flatten(loop_vars));

// Let the context know the loop variables so the loop variables
// would be added in the outer contexts properly.
_InitializeValues(loop_vars);
var real_vars = loop_vars;
Tensor[] enter_vars = null;
tf_with(ops.control_dependencies(null), delegate
if (loop_vars is Tensor[] real_vars)
{
enter_vars = real_vars.Select(x => _Enter(x,
_name,
is_constant: false,
parallel_iterations: _parallel_iterations,
use_input_shape: shape_invariants == null))
.ToArray();

foreach(var x in enter_vars)
_InitializeValues(real_vars);
Tensor[] enter_vars = null;
tf_with(ops.control_dependencies(null), delegate
{
enter_vars = real_vars.Select(x => _Enter(x,
_name,
is_constant: false,
parallel_iterations: _parallel_iterations,
use_input_shape: shape_invariants == null))
.ToArray();

foreach (var x in enter_vars)
{
x.graph.prevent_feeding(x);
if (_outer_context != null)
_outer_context.AddInnerOp(x.op);
}
});

// Finds the closest enclosing non-None control pivot.
var outer_context = _outer_context;
while (outer_context != null)
{
x.graph.prevent_feeding(x);
if (_outer_context != null)
_outer_context.AddInnerOp(x.op);

}
});

// Finds the closest enclosing non-None control pivot.
var outer_context = _outer_context;
while (outer_context != null)
{
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);

// Fix the control inputs and control flow context of these enter ops.
_FixControlInputsAndContext(enter_vars);
_InitializeValues(enter_vars);
_loop_enters = enter_vars.ToList();

var merge_vars = enter_vars
.Select(x => merge(new[] { x, x }))
.ToArray();

_pivot_for_pred = merge_vars[0];

// Build the graph for pred.
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem)));
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
.ToArray();

// Build the graph for body.
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
// Convert TensorArray flow variables inside the context back into
// their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
/*var body_result = body(packed_vars_for_body[0]);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

// Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };
// Convert TensorArrays returned by body into their flow variables
var result = new[] { body_result };

var next_vars = new List<Tensor>();
foreach (var (m, v) in zip(merge_vars, result))
next_vars.Add(_AddNextAndBackEdge(m, v));

// Add the exit ops.
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
_loop_exits = exit_vars;

// Exit the loop.
// ExitResult(exit_vars);
return (original_body_result, exit_vars.ToArray());*/
}

_SetShapeInvariants(real_vars, enter_vars, shape_invariants);

// Fix the control inputs and control flow context of these enter ops.
_FixControlInputsAndContext(enter_vars);
_InitializeValues(enter_vars);
_loop_enters = enter_vars.ToList();

var merge_vars = enter_vars
.Select(x => merge(new[] { x, x }))
.ToArray();

_pivot_for_pred = merge_vars[0];

// Build the graph for pred.
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0]));
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
.ToArray();

// Build the graph for body.
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
// Convert TensorArray flow variables inside the context back into
// their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var body_result = body(packed_vars_for_body[0]);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

// Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };
// Convert TensorArrays returned by body into their flow variables
var result = new[] { body_result };

var next_vars = new List<Tensor>();
foreach (var (m, v) in zip(merge_vars, result))
next_vars.Add(_AddNextAndBackEdge(m, v));

// Add the exit ops.
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
_loop_exits = exit_vars;

// Exit the loop.
// ExitResult(exit_vars);
return (original_body_result, exit_vars.ToArray());
throw new NotImplementedException("");
}

private void _FixControlInputsAndContext(Tensor[] enters)


Loading…
Cancel
Save