
MaybeCreateControlFlowState

tags/v0.12
Oceania2018 6 years ago
parent
commit 59b7eb0365
5 changed files with 232 additions and 108 deletions
  1. +95 -83  src/TensorFlowNET.Core/Gradients/gradients_util.cs
  2. +4 -3    src/TensorFlowNET.Core/Layers/Layer.cs
  3. +38 -20  src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  4. +93 -2   src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  5. +2 -0    src/TensorFlowNET.Core/Operations/math_ops.cs

src/TensorFlowNET.Core/Gradients/gradients_util.cs (+95 -83)

@@ -55,6 +55,9 @@ namespace Tensorflow
* is more than one.
**/
var grads = new Dictionary<string, List<List<Tensor>>>();
Operation[] reachable_to_ops = null;
ControlFlowState loop_state = null;
Dictionary<string, int> pending_count = null;

tf_with(ops.name_scope(name, "gradients",
values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
@@ -81,7 +84,7 @@ namespace Tensorflow
var to_ops = ys.Select(x => x.op).ToList();
var from_ops = xs.Select(x => x.op).ToList();
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);

// Add the initial gradients for the ys.
foreach (var (y, grad_y) in zip(ys, grad_ys))
@@ -120,126 +123,135 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
if(op.name == "rnn/while/basic_rnn_cell/Tanh")
if(op.name == "rnn/while/Exit")
{

}
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);
var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);

Tensor[] in_grads = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{
// A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients.
if (loop_state != null)
loop_state.EnterGradWhileContext(op, before: true);
var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
if (loop_state != null)
loop_state.ExitGradWhileContext(op, before: true);

Func<Operation, Tensor[], Tensor[]> grad_fn = null;
try
{
grad_fn = ops.get_gradient_function(op);
}
catch (LookupError)
Tensor[] in_grads = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{
if (is_func_call)
// A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients.

Func<Operation, Tensor[], Tensor[]> grad_fn = null;
try
{
if (is_partitioned_call)
grad_fn = ops.get_gradient_function(op);
}
catch (LookupError)
{
if (is_func_call)
{
if (is_partitioned_call)
{

}
else
{

}
}
else
{

throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
}
}
else
{
throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
}
}

if (loop_state != null)
loop_state.EnterGradWhileContext(op, before: false);
if (loop_state != null)
loop_state.EnterGradWhileContext(op, before: false);

if ((is_func_call || grad_fn != null) && has_out_grads)
{
// NOTE: If _AggregatedGrads didn't compute a value for the i'th
// output, it means that the cost does not depend on output[i],
// therefore dC/doutput[i] is 0.
foreach (var (i, out_grad) in enumerate(out_grads))
if ((is_func_call || grad_fn != null) && has_out_grads)
{
if (out_grad == null &&
(grad_fn == null || _IsTrainable(op.outputs[i])))
// NOTE: If _AggregatedGrads didn't compute a value for the i'th
// output, it means that the cost does not depend on output[i],
// therefore dC/doutput[i] is 0.
foreach (var (i, out_grad) in enumerate(out_grads))
{
// Only trainable outputs or outputs for a function call that
// will use SymbolicGradient get a zero gradient. Gradient
// functions should ignore the gradient for other outputs.
if (loop_state != null)
out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
else
out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
if (out_grad == null &&
(grad_fn == null || _IsTrainable(op.outputs[i])))
{
// Only trainable outputs or outputs for a function call that
// will use SymbolicGradient get a zero gradient. Gradient
// functions should ignore the gradient for other outputs.
if (loop_state != null)
out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
else
out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
}
}
}

tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
{
if (grad_fn != null)
tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
{
in_grads = _MaybeCompile(grad_scope,
op,
out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
null,
grad_fn);
}
else
{
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
}
_VerifyGeneratedGradients(in_grads, op);
if (gate_gradients && in_grads.Count(x => x != null) > 1)
{
ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
in_grads = control_flow_ops.tuple(in_grads);
}
});
if (grad_fn != null)
{
in_grads = _MaybeCompile(grad_scope,
op,
out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
null,
grad_fn);
}
else
{
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
}
_VerifyGeneratedGradients(in_grads, op);
if (gate_gradients && in_grads.Count(x => x != null) > 1)
{
ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
in_grads = control_flow_ops.tuple(in_grads);
}
});
}
else
{
// If no grad_fn is defined or none of out_grads is available,
// just propagate a list of None backwards.
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
}
}
else
{
// If no grad_fn is defined or none of out_grads is available,
// just propagate a list of None backwards.
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
}
}
else
{
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
}

var inputs = _NonEagerInputs(op, xs).ToList();
foreach (var (t_in, in_grad) in zip(inputs, in_grads))
{
if (in_grad != null)
var inputs = _NonEagerInputs(op, xs).ToList();
foreach (var (t_in, in_grad) in zip(inputs, in_grads))
{
if (!(in_grad is null) &&
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
if (in_grad != null)
{
in_grad.set_shape(t_in.TensorShape);
}
if (!(in_grad is null) &&
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
{
in_grad.set_shape(t_in.TensorShape);
}

_SetGrad(grads, t_in, in_grad);
_SetGrad(grads, t_in, in_grad);
}
}
}

if (loop_state != null)
loop_state.ExitGradWhileContext(op, before: false);
}
// Update pending count for the inputs of op and enqueue ready ops.
_UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
}
});

if (loop_state != null)
loop_state.PostProcessing();
return xs.Select(x => _GetGrad(grads, x)).ToArray();
}
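
The first two hunks hoist reachable_to_ops, pending_count and loop_state out of the tf_with closure: a `var (…)` deconstruction inside the lambda would scope those results to the lambda only, while the gradient loop below needs them afterwards. A minimal stand-alone C# sketch of that pattern (plain delegates, no TensorFlow.NET types; the names here are illustrative only):

// Hypothetical illustration of the declaration-hoisting pattern, not library code.
using System;

class ClosureScopeDemo
{
    static (int, int) Analyze() => (3, 7);

    static void Main()
    {
        // Declared outside the lambda so the results survive after it runs,
        // mirroring how gradients_util.cs now declares reachable_to_ops,
        // pending_count and loop_state before tf_with(...).
        int reachable = 0;
        int pending = 0;

        Action scope = () =>
        {
            // `var (reachable, pending) = Analyze();` here would create new locals
            // visible only inside this lambda; plain assignment updates the outer
            // variables instead.
            (reachable, pending) = Analyze();
        };

        scope();
        Console.WriteLine($"{reachable}, {pending}");   // prints: 3, 7
    }
}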



src/TensorFlowNET.Core/Layers/Layer.cs (+4 -3)

@@ -50,10 +50,11 @@ namespace Tensorflow.Layers

public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
{
return __call__(inputs, training: training);
var results = __call__(inputs, training: training);
return (results[0], results[1]);
}

public (Tensor, Tensor) __call__(Tensor inputs,
public Tensor[] __call__(Tensor inputs,
Tensor training = null,
Tensor state = null,
VariableScope scope = null)
@@ -73,7 +74,7 @@ namespace Tensorflow.Layers
auxiliary_name_scope: false);
}

(Tensor, Tensor) outputs = (null, null);
Tensor[] outputs = null;
tf_with(scope_context_manager, scope2 =>
{
_current_scope = scope2;
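
__call__ now returns a Tensor[] instead of a fixed (Tensor, Tensor) pair, and apply unwraps the first two elements. A hedged usage sketch (the cell, inputs and training variables are hypothetical stand-ins, not code from this commit):

// Hypothetical usage; `cell`, `inputs` and `training` stand in for real objects.
Tensor[] results = cell.__call__(inputs, training: training);
Tensor output = results[0];   // primary output
Tensor state  = results[1];   // e.g. an RNN cell's new state

// Or, through the tuple-returning convenience wrapper:
var (o, s) = cell.apply(inputs, training: training);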


src/TensorFlowNET.Core/Operations/control_flow_ops.cs (+38 -20)

@@ -151,27 +151,50 @@ namespace Tensorflow
/// <param name="colocate_gradients_with_ops"></param>
public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
{
var flag = new List<Operation>();
ControlFlowState loop_state = null;

foreach (var op in between_op_list)
int pos = 0;
while(pos < between_op_list.Count)
{
var op = between_op_list[pos];
if (IsLoopExit(op))
{
if(loop_state == null)
if (loop_state == null)
{
loop_state = new ControlFlowState();
}
if (colocate_gradients_with_ops)
ops.colocate_with(op);
loop_state.AddWhileContext(op, between_op_list, between_ops);
}
pos++;
}

return loop_state;
}

public static bool IsLoopExit(Operation op)
=> op.OpType == "Exit" || op.OpType == "RefExit";

public static bool IsLoopSwitch(Operation op)
{
if(IsSwitch(op))
{
var ctxt = op._get_control_flow_context();
return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
}
return false;
}

public static bool IsCondSwitch(Operation op)
{
return op.OpType == "Exit" || op.OpType == "RefExit";
throw new NotImplementedException("IsCondSwitch");
}

public static bool IsSwitch(Operation op)
=> op.type == "Switch" || op.type == "RefSwitch";

public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null)
{
return tf_with(ops.name_scope(name, "tuple", tensors), scope =>
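
MaybeCreateControlFlowState now walks between_op_list by index rather than foreach: loop_state.AddWhileContext can append new operations to that list while the walk is in progress, and C# foreach throws as soon as the collection is modified. A stand-alone sketch of the difference (plain List<int>, no TensorFlow.NET types):

// Hypothetical illustration of iterating a list that grows during the loop.
using System;
using System.Collections.Generic;

class GrowingListDemo
{
    static void Main()
    {
        var items = new List<int> { 1, 2, 3 };

        // Index-based walk: items appended during the loop are also visited,
        // which is what the pos-based loop in MaybeCreateControlFlowState relies on.
        int pos = 0;
        while (pos < items.Count)
        {
            if (items[pos] == 2)
                items.Add(4);        // grows the list mid-iteration, no exception
            pos++;
        }
        Console.WriteLine(string.Join(",", items));   // prints: 1,2,3,4

        // foreach (var x in items) { items.Add(5); }
        // would throw InvalidOperationException: collection was modified.
    }
}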
@@ -224,15 +247,10 @@ namespace Tensorflow
//TODO: missing original code
//if context.executing_eagerly():
// return output_tensor
var values = new List<object>();
values.AddRange(dependencies);
values.Add(output_tensor);

return tf_with(ops.name_scope(name, "control_dependency", values), scope =>
return tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope =>
{
name = scope;
// TODO: missing original code
//with ops.colocate_with(output_tensor):
ops.colocate_with(output_tensor);
{
return tf_with(ops.control_dependencies(dependencies), ctl =>
{
@@ -431,6 +449,7 @@ namespace Tensorflow
var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.Select(m => (Tensor)m)
.ToArray();

var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);
@@ -479,6 +498,7 @@ namespace Tensorflow

var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
.Select(m => (Tensor)m)
.ToArray();

var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);
@@ -519,7 +539,7 @@ namespace Tensorflow
/// <param name="inputs">inputs: The input tensors, at most one of which is available.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns></returns>
public static Tensor merge(Tensor[] inputs, string name = null)
public static MergeOutput merge(Tensor[] inputs, string name = null)
{
if (inputs.Any(x => x == null))
throw new ValueError($"At least one of the merge inputs is null: {inputs}");
@@ -529,7 +549,7 @@ namespace Tensorflow
inputs = inputs.Select(inp =>
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
.ToArray();
return gen_control_flow_ops.merge(inputs, name)[0];
return gen_control_flow_ops.merge(inputs, name);
});
}

@@ -602,7 +622,7 @@ namespace Tensorflow
/// <param name="body"></param>
/// <param name="loop_vars"></param>
/// <param name="i"></param>
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
public static TItem while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
TensorShape[] shape_invariants = null,
int parallel_iterations = 10,
bool back_prop = true,
@@ -611,7 +631,7 @@ namespace Tensorflow
Tensor maximum_iterations = null,
bool return_same_structure = false)
{
tf_with(ops.name_scope(name, "while", loop_vars), scope =>
return tf_with(ops.name_scope(name, "while", loop_vars), scope =>
{
if (loop_vars == null)
throw new ValueError("No loop variables provided");
@@ -666,13 +686,11 @@ namespace Tensorflow
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants,
return_same_structure);

if (maximum_iterations != null)
return results[1];
else
return results[0];
//if (maximum_iterations != null)
return results.Item;
//else
//return results;
});

throw new NotImplementedException("while_loop");
}

/// <summary>
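
while_loop is now generic in its return type (TItem), so a call with a single Tensor loop variable yields a Tensor instead of being forced through a fixed signature. A hedged sketch, assuming the usual tf.constant / tf.less / tf.add bindings exist in this version of the library:

// Hedged sketch; tf.constant, tf.less and tf.add are assumed bindings, not verified here.
var i = tf.constant(0);
Tensor result = control_flow_ops.while_loop(
    cond: x => tf.less(x, 10),    // run while i < 10
    body: x => tf.add(x, 1),      // i = i + 1
    loop_vars: i);                // TItem is inferred as Tensor from loop_vars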


src/TensorFlowNET.Core/Operations/control_flow_util.py.cs (+93 -2)

@@ -15,6 +15,7 @@
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;

@@ -60,6 +61,45 @@ namespace Tensorflow
public static bool IsSwitch(Operation op)
{
return op.type == "Switch" || op.type == "RefSwitch";
}
public static WhileContext GetWhileContext(Operation op)
=> op.GetWhileContext();

public static bool IsCondSwitch(Operation op)
{
if (!IsSwitch(op))
return false;
if (op.outputs == null || op.outputs.Length == 0)
return false;
// Switch nodes are not part of the cond control flow context that they
// represent, so consider the consumers of its outputs to determine if it is
// cond switch or not. A switch is a cond switch iff all its consumers are in
// cond contexts.
var is_cond_switch = true;
foreach(var o in op.outputs)
{
foreach(var c in o.consumers())
{
var ctxt = c._get_control_flow_context();
if (IsLoopEnter(c))
ctxt = ctxt.outer_context;
is_cond_switch = is_cond_switch &&(ctxt != null && ctxt.IsCondContext());
}
}
return is_cond_switch;
}

public static bool IsLoopSwitch(Operation op)
{
if (IsSwitch(op))
{
var ctxt = op._get_control_flow_context();
return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
}
return false;
}

/// <summary>
@@ -87,13 +127,64 @@ namespace Tensorflow
valid = true;
else
{
throw new NotImplementedException("");
var while_ctxt = GetContainingWhileContext(op_ctxt);
var input_while_ctxt = GetContainingWhileContext(input_ctxt);
if (while_ctxt == null)
{
throw new NotImplementedException("CheckInputFromValidContext");
}
else if (IsContainingContext(while_ctxt, input_while_ctxt))
{
// input_op is in a while loop which contains op's while loop (or not in a
// while loop at all).
valid = true;
}
else if (while_ctxt.grad_state != null &&
IsContainingContext(while_ctxt.grad_state.forward_context,
input_while_ctxt))
{
valid = true;
}
else
throw new NotImplementedException("CheckInputFromValidContext");
}
if (!valid)
{
throw new NotImplementedException("");
throw new NotImplementedException("CheckInputFromValidContext");
}
}
public static Operation GetLoopConstantEnter(Tensor value)
{
var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" };
var op = value.op;
while (id_ops.Contains(op.type))
op = op.inputs[0].op;
return IsLoopConstantEnter(op) ? op : null;
}

public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt)
{
while(ctxt != maybe_containing_ctxt)
{
if (ctxt == null)
return false;
ctxt = ctxt.outer_context as WhileContext;
}
return true;
}

public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null)
{
while (ctxt != null)
{
if (ctxt.IsWhileContext() || ctxt == stop_ctxt)
return ctxt as WhileContext;
ctxt = ctxt.outer_context;
}
return null;
}
}
}
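
IsContainingContext walks the outer_context chain; passing null as maybe_containing_ctxt means "not inside any while loop at all", and the walk then succeeds once it falls off the end of the chain. A stand-alone sketch of that walk with a toy context type (not the library's classes):

// Hypothetical illustration of the containment walk; Ctx is not a TensorFlow.NET type.
using System;

class Ctx
{
    public Ctx Outer;
    public string Name;
}

class ContainmentDemo
{
    // Mirrors IsContainingContext: true if `target` encloses `ctxt`,
    // or if target is null (i.e. "not in a while loop at all").
    static bool IsContaining(Ctx ctxt, Ctx target)
    {
        while (ctxt != target)
        {
            if (ctxt == null) return false;
            ctxt = ctxt.Outer;
        }
        return true;
    }

    static void Main()
    {
        var outer = new Ctx { Name = "outer" };
        var inner = new Ctx { Name = "inner", Outer = outer };

        Console.WriteLine(IsContaining(inner, outer));  // True: outer encloses inner
        Console.WriteLine(IsContaining(outer, inner));  // False
        Console.WriteLine(IsContaining(inner, null));   // True: null means no enclosing loop
    }
}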

src/TensorFlowNET.Core/Operations/math_ops.cs (+2 -0)

@@ -159,6 +159,8 @@ namespace Tensorflow
});
}

public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name);
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.equal(x, y, name: name);
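
greater_equal is a thin forwarder to gen_math_ops.greater_equal, presumably for use by the control-flow changes above. A hedged usage sketch (the tensor names are hypothetical):

// Hypothetical usage; `iterations` and `max_iterations` stand in for real tensors.
Tensor done = math_ops.greater_equal(iterations, max_iterations);  // element-wise boolean tensor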


