
gradients/flow control: added much missing structure

tags/v0.9 · Meinrad Recheis, 6 years ago · commit 692faaa9bd
13 changed files with 1014 additions and 50 deletions
  1. +225 -4    src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
  2. +4 -3      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  3. +2 -1      src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
  4. +37 -14    src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  5. +27 -6     src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  6. +277 -0    src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs
  7. +398 -0    src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs
  8. +11 -9     src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
  9. +15 -0     src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  10. +3 -3     src/TensorFlowNET.Core/Operations/Operation.Control.cs
  11. +12 -7    src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  12. +1 -1     src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  13. +2 -2     src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs

+ 225  - 4    src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs

@@ -1,12 +1,79 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow.Gradients
{
/// <summary>
/// Gradients for operators defined in control_flow_ops.py.cs
/// </summary>
public class control_flow_grad
{
/// <summary>
/// Gradients for a Switch op is calculated using a Merge op.
///
/// If the switch is a loop switch, it will be visited twice. We create
/// the merge on the first visit, and update the other input of the merge
/// on the second visit. A next_iteration is also added on second visit.
/// </summary>
/// <returns></returns>
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
{
throw new NotImplementedException("_SwitchGrad");
//graph = ops.get_default_graph()
//# pylint: disable=protected-access
//op_ctxt = op._get_control_flow_context()
//grad_ctxt = graph._get_control_flow_context()
//# pylint: enable=protected-access
//if isinstance(op_ctxt, WhileContext):
// merge_grad = grad_ctxt.grad_state.switch_map.get(op)
// if merge_grad is not None:
// # This is the second time this Switch is visited. It comes from
// # the non-exit branch of the Switch, so update the second input
// # to the Merge.
// # TODO(yuanbyu): Perform shape inference with this new input.
// if grad[1] is not None:
// # pylint: disable=protected-access
// control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
// enforce_shape_invariant=False)
// # pylint: enable=protected-access
// return None, None
// elif grad[0] is not None:
// # This is the first time this Switch is visited. It comes from
// # the Exit branch, which is grad[0]. grad[1] is empty at this point.
// # Use grad[0] for both inputs to merge for now, but update the second
// # input of merge when we see this Switch the second time.
// merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
// grad_ctxt.grad_state.switch_map[op] = merge_grad
// return merge_grad, None
// else:
// # This is the first time this Switch is visited. It comes from the
// # Identity branch. Such a Switch has `None` gradient for the Exit branch,
// # meaning the output is not differentiable.
// return None, None
//elif isinstance(op_ctxt, CondContext):
// zero_grad = grad[1 - op_ctxt.branch]
// # At this point, we have created zero_grad guarded by the right switch.
// # Unfortunately, we may still get None here for not trainable data types.
// if zero_grad is None:
// # For resource variables we get None always on the other branch, so bypass
// # this.
// if op.inputs[0].dtype == dtypes.resource:
// return merge(
// [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
// return None, None
// return merge(grad, name="cond_grad")[0], None
//else:
// false_grad = switch(grad[0], op.inputs[1])[0]
// true_grad = switch(grad[1], op.inputs[1])[1]
// return merge([false_grad, true_grad])[0], None
}
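
The commented Python above spells out the three cases. As a rough illustration of where the port is headed, here is a minimal C# sketch of only the CondContext case; the helper name, the Operation-typed parameter, and a public `branch` property on CondContext are assumptions, and the sketch itself is not part of this commit. It reuses gen_control_flow_ops.merge from this repository.

    // Sketch only: the CondContext case of _SwitchGrad, following the commented Python.
    public Tensor[] _SwitchGradCondSketch(Operation op, Tensor[] grads)
    {
        var op_ctxt = op._get_control_flow_context() as CondContext;
        // The gradient arriving from the branch that was not taken.
        var zero_grad = grads[1 - op_ctxt.branch];   // `branch` is an assumed public property
        if (zero_grad == null)
            // Non-trainable dtypes (and resource variables) may produce no gradient here.
            return new Tensor[] { null, null };
        // Merge the gradients of both branches; the pred input gets no gradient.
        var (merged, _) = gen_control_flow_ops.merge(grads, name: "cond_grad");
        return new Tensor[] { merged, null };
    }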

/// <summary>
/// Gradients for a Merge op are calculated using a Switch op.
/// </summary>
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -14,10 +81,164 @@ namespace Tensorflow.Gradients
var input_op = op.inputs[0].op;
var graph = ops.get_default_graph();
var op_ctxt = control_flow_util.GetOutputContext(input_op);
var pred = (op_ctxt as CondContext).pred;
var grad_ctxt = graph._get_control_flow_context();
switch (op_ctxt)
{
case WhileContext cwhile:
{
return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot);
}
case CondContext ccond:
{
var pred = ccond.pred;
if (grad_ctxt != null && grad_ctxt.grad_state != null)
{
//# This Merge node is part of a cond within a loop.
//# The backprop needs to have the value of this predicate for every
//# iteration. So we must have its values accumulated in the forward, and
//# use the accumulated values as the predicate for this backprop switch.
var grad_state = grad_ctxt.grad_state;
var real_pred = grad_state.history_map[pred.name] as Tensor;
if (real_pred == null)
{
//# Remember the value of pred for every iteration.
grad_ctxt = grad_state.grad_context;
grad_ctxt.Exit();
var history_pred = grad_state.AddForwardAccumulator(pred);
grad_ctxt.Enter();
//# Add the stack pop op. If pred.op is in a (outer) CondContext,
//# the stack pop will be guarded with a switch.
real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred);
grad_state.history_map[pred.name] = real_pred;
}
pred = real_pred;
}
var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
return results;
}
default:
{
var num_inputs = op.inputs.Length;
var cond = new Tensor[num_inputs];
for (int i = 0; i < num_inputs; i++)
cond[i] = math_ops.equal(op.outputs[1], i);
var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray();
return result;
}
}

var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
return new Tensor[] { results.Item1, results.Item2 };
}
}

public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
{
return _MergeGrad(op, grads);
}

/// <summary>
/// Gradients for an exit op are calculated using an Enter op.
/// </summary>
public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
{
throw new NotImplementedException("_ExitGrad");
// graph = ops.get_default_graph()
//# pylint: disable=protected-access
// op_ctxt = op._get_control_flow_context()
// grad_ctxt = graph._get_control_flow_context()
// # pylint: enable=protected-access
// if not grad_ctxt.back_prop:
// # The flag `back_prop` is set by users to suppress gradient
// # computation for this loop. If the attribute `back_prop` is false,
// # no gradient computation.
// return None

// if op_ctxt.grad_state:
// raise TypeError("Second-order gradient for while loops not supported.")

// if isinstance(grad, ops.Tensor) :
// grad_ctxt.AddName(grad.name)
// else:
// if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
// raise TypeError("Type %s not supported" % type(grad))
// grad_ctxt.AddName(grad.values.name)
// grad_ctxt.AddName(grad.indices.name)
// dense_shape = grad.dense_shape
// if dense_shape is not None:
// grad_ctxt.AddName(dense_shape.name)
// grad_ctxt.Enter()
// # pylint: disable=protected-access
// result = control_flow_ops._Enter(
// grad, grad_ctxt.name, is_constant=False,
// parallel_iterations=grad_ctxt.parallel_iterations,
// name="b_exit")
// # pylint: enable=protected-access
// grad_ctxt.loop_enters.append(result)
// grad_ctxt.Exit()
// return result
}
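
For reference, a minimal C# sketch of _ExitGrad following the commented Python, intended only to show the control flow. WhileContext members such as AddName and parallel_iterations, plus a C# port of the Python _Enter helper, are assumptions that this commit does not yet provide.

    // Sketch only: rough shape of _ExitGrad once the WhileContext surface is ported.
    public Tensor[] _ExitGradSketch(Operation op, Tensor[] grads)
    {
        var grad = grads[0];
        var graph = ops.get_default_graph();
        var op_ctxt = op._get_control_flow_context();
        var grad_ctxt = graph._get_control_flow_context() as WhileContext;

        if (!grad_ctxt.back_prop)
            return null;   // the user suppressed gradient computation for this loop
        if (op_ctxt.grad_state != null)
            throw new InvalidOperationException("Second-order gradient for while loops not supported.");

        grad_ctxt.AddName(grad.name);            // assumed WhileContext member
        grad_ctxt.Enter();
        var result = control_flow_ops._Enter(    // assumed C# port of the Python _Enter helper
            grad, grad_ctxt.name, is_constant: false,
            parallel_iterations: grad_ctxt.parallel_iterations,  // assumed property
            name: "b_exit");
        grad_ctxt.Exit();
        return new[] { result };
    }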

/// <summary>
/// A forward next_iteration is translated into a backprop identity.
///
/// Note that the backprop next_iteration is added in switch grad.
/// </summary>
public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad)
{
return (_, grad);
}

public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad)
{
return (_, grad);
}

/// <summary>
/// Gradients for an Enter are calculated using an Exit op.
///
/// For loop variables, grad is the gradient so just add an exit.
/// For loop invariants, we need to add an accumulator loop.
/// </summary>
public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad)
{
throw new NotImplementedException("_EnterGrad");
// graph = ops.get_default_graph()
//# pylint: disable=protected-access
// grad_ctxt = graph._get_control_flow_context()
// # pylint: enable=protected-access
// if not grad_ctxt.back_prop:
// # Skip gradient computation, if the attribute `back_prop` is false.
// return grad
// if grad_ctxt.grad_state is None:
// # Pass the gradient through if we are not in a gradient while context.
// return grad
// if op.get_attr("is_constant"):
// # Add a gradient accumulator for each loop invariant.
// if isinstance(grad, ops.Tensor) :
// result = grad_ctxt.AddBackpropAccumulator(op, grad)
// elif isinstance(grad, ops.IndexedSlices) :
// result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
// else:
// # TODO(yuanbyu, lukasr): Add support for SparseTensor.
// raise TypeError("Type %s not supported" % type(grad))
// else:
// result = exit(grad)
// grad_ctxt.loop_exits.append(result)
// grad_ctxt.ExitResult([result])
// return result
}
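
A matching sketch for the two cases described above (loop invariants get an accumulator, loop variables simply exit). AddBackpropAccumulator and a C# exit helper are assumptions, not members introduced by this commit.

    // Sketch only: rough shape of _EnterGrad once the accumulator machinery is ported.
    public Tensor[] _EnterGradSketch(Operation op, Tensor[] grads)
    {
        var grad = grads[0];
        var graph = ops.get_default_graph();
        var grad_ctxt = graph._get_control_flow_context() as WhileContext;

        // Outside a gradient while context, or with back_prop disabled, pass it through.
        if (!grad_ctxt.back_prop || grad_ctxt.grad_state == null)
            return grads;

        Tensor result;
        if ((bool)op.get_attr("is_constant"))
            // Loop invariant: accumulate its gradient over all iterations.
            result = grad_ctxt.AddBackpropAccumulator(op, grad);   // assumed member
        else
            // Loop variable: let the gradient leave the loop through an Exit.
            result = control_flow_ops.exit(grad);                  // assumed helper
        return new[] { result };
    }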

public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad)
{
return _EnterGrad(op, grad);
}

/// <summary>
/// Stop backprop for the predicate of a while loop.
/// </summary>
public object _LoopCondGrad(object _)
{
return null;
}
}
}

+ 4  - 3    src/TensorFlowNET.Core/Graphs/Graph.Control.cs

@@ -3,13 +3,14 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Operations;

namespace Tensorflow
{
public partial class Graph
{
// Current control flow context. It could be either CondContext or WhileContext
public IControlFlowContext _control_flow_context;
public ControlFlowContext _control_flow_context;

// represents the nested with(...) statements
public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>();
@@ -97,7 +98,7 @@ namespace Tensorflow
/// Returns the current control flow context.
/// </summary>
/// <returns>A context object.</returns>
public IControlFlowContext _get_control_flow_context()
public ControlFlowContext _get_control_flow_context()
{
return _control_flow_context;
}
@@ -106,7 +107,7 @@ namespace Tensorflow
/// Sets the current control flow context.
/// </summary>
/// <param name="ctx">a context object.</param>
public void _set_control_flow_context(IControlFlowContext ctx)
public void _set_control_flow_context(ControlFlowContext ctx)
{
_control_flow_context = ctx;
}


+ 2  - 1    src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -15,7 +16,7 @@ namespace Tensorflow
private List<ITensorOrOperation> _seen_nodes;
private List<_ControlDependenciesController> _old_stack;
private bool _new_stack;
private IControlFlowContext _old_control_flow_context;
private ControlFlowContext _old_control_flow_context;

public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();


+ 37  - 14    src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations.ControlFlows;

namespace Tensorflow.Operations
{
@@ -107,8 +108,8 @@ namespace Tensorflow.Operations
with(ops.control_dependencies(null), ctrl =>
{
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
result = new[] { r0, r1 }[_branch];
var results = control_flow_ops._SwitchRefOrTensor(result, _pred);
result = results[_branch];
if (_outer_context != null)
_outer_context.AddInnerOp(result.op);
});
@@ -118,7 +119,7 @@ namespace Tensorflow.Operations

// Mark Switch output as seen by this context and any outer contexts,
// just like what we do for normal op outputs in _AddOpInternal() below.
IControlFlowContext ctxt = this;
ControlFlowContext ctxt = this;
while (ctxt != null)
{
ctxt.values.Add(result.name);
@@ -223,8 +224,8 @@ namespace Tensorflow.Operations
_values.Add(real_val.name);
_external_values[real_val.name] = real_val;
}
var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred);
real_val = new[] {t0, t1}[_branch];
var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred);
real_val = results[_branch];
_external_values[val.name] = real_val;
}
else
@@ -238,8 +239,8 @@ namespace Tensorflow.Operations
return real_val;
}
protected override void _AddOpInternal(Operation op)
{
protected override void _AddOpInternal(Operation op)
{
if (op.inputs.Length == 0)
{
//If we're in a while loop, remove any control inputs from outside the
@@ -282,11 +283,11 @@ namespace Tensorflow.Operations
// TODO: implement below code dependencies
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
// op._add_control_input(_pivot.op);
}
// Mark op's outputs as seen by this context and any outer contexts.
}
// Mark op's outputs as seen by this context and any outer contexts.
var output_names = op.outputs.Select(x => x.name).ToArray();
IControlFlowContext ctxt = this;
ControlFlowContext ctxt = this;
while (ctxt != null)
{
foreach (var name in output_names)
@@ -298,9 +299,31 @@ namespace Tensorflow.Operations
op.graph.prevent_fetching(op);

if (_outer_context != null)
_outer_context.AddInnerOp(op);
}
_outer_context.AddInnerOp(op);
}

public override GradLoopState grad_state
{
get
{
var whc = GetWhileContext();
if (whc != null)
return whc.grad_state;
return null;
}
}

public override bool back_prop
{
get
{
var whc = GetWhileContext();
if (whc != null)
return whc.back_prop;
return false;
}
}

public CondContextDef to_proto(string export_scope)
{
throw new NotImplementedException();


+ 27  - 6    src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations.ControlFlows;

namespace Tensorflow.Operations
{
@@ -22,21 +23,25 @@ namespace Tensorflow.Operations
/// 4. A ControlFlowContext has _context_stack.
/// Pushed and popped by ctxt.Enter() and ctxt.Exit()
/// </summary>
public abstract class ControlFlowContext : Python, IPython, IControlFlowContext
public abstract class ControlFlowContext : Python, IPython
{
/// <summary>
/// The predicate tensor in this branch
/// </summary>
protected Tensor _pivot;
public Tensor pivot
{
get => _pivot;
}

protected Stack<IControlFlowContext> _context_stack;
protected IControlFlowContext _outer_context;
protected Stack<ControlFlowContext> _context_stack;
protected ControlFlowContext _outer_context;

protected Dictionary<string, ITensorOrOperation> _external_values;

public ControlFlowContext()
{
_context_stack = new Stack<IControlFlowContext>();
_context_stack = new Stack<ControlFlowContext>();
}

public string name { get => _name; }
@@ -111,8 +116,13 @@ namespace Tensorflow.Operations
_AddOpInternal(op);
}

public IControlFlowContext outer_context { get { return _outer_context; } }
public ControlFlowContext outer_context { get { return _outer_context; } }
public HashSet<string> values => _values;

public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method");

public virtual bool back_prop => throw new NotImplementedException("abstract method");

public virtual Tensor AddValue(Tensor val)
{
// to be overridden
@@ -147,7 +157,7 @@ namespace Tensorflow.Operations
/// <summary>
/// Returns true if `maybe_containing_ctxt` is or contains `ctxt`.
/// </summary>
public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt)
{
while (ctxt != maybe_containing_ctxt)
{
@@ -164,6 +174,16 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs;
}

/// <summary>
/// Return the while context containing this context
/// </summary>
public virtual WhileContext GetWhileContext()
{
if (_outer_context != null)
return _outer_context.GetWhileContext();
return null;
}

public object to_proto()
{
throw new NotImplementedException();
@@ -173,5 +193,6 @@ namespace Tensorflow.Operations
public void Dispose()
{
}

}
}

+ 277  - 0    src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs

@@ -0,0 +1,277 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Operations.ControlFlows
{
/// <summary>
/// Maintain the mapping from the loops to their grad states.
/// </summary>
public class ControlFlowState
{
//class ControlFlowState(object):
// """Maintain the mapping from the loops to their grad states."""
// def __init__(self):
// self._map = {} # maps forward loop context to GradLoopState
// def GetGradState(self, op, before):
// """Return the grad state for this op if it's in a forward loop context."""
// if before and util.IsLoopExit(op):
// forward_ctxt = op._get_control_flow_context()
// forward_ctxt = forward_ctxt.outer_context
// if forward_ctxt:
// forward_ctxt = forward_ctxt.GetWhileContext()
// else:
// forward_ctxt = _GetWhileContext(op)
// if forward_ctxt:
// return self._map.get(forward_ctxt)
// return None
// def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
// """Process all the "unused" loop exits.
// The "unused" exits of the loops are added to `unused_exits`. An exit is
// unused if its pending_count is 0. If there is an exit with real gradient,
// all these deferred exits will enter the backprop loop with zero gradient.
// Otherwise, they will enter the backprop loop with None. As an example,
// people often write:
// ```python
// v1, _ = tf.while_loop(p, b, [x1, x2])
// result = gradients(v1, x1)
// ```
// The exit node for x2 is not included by the betweenness analysis. But we
// need to backprop x2 if x2 is involved in computing v1.
// Args:
// pending_count: The number of backprop inputs for every op.
// to_ops_set: The set of ops for ys in gradients(ys, xs)
// Returns:
// The set of unused loop exits that we know at this point we need
// to backprop.
// """
// loop_exits = []
// for grad_state in self._map.values():
// for y in grad_state.forward_loop_exits:
// if pending_count[y.op] == 0:
// grad_state.pending_exits_count -= 1
// if y.op not in to_ops_set:
// grad_state.unused_exits.append(y)
// if grad_state.pending_exits_count == 0:
// loop_exits.extend(grad_state.unused_exits)
// # Need to include Enters in backprop for higher-order gradients.
// for y in grad_state.forward_context.loop_enters:
// if pending_count[y.op] == 0:
// pending_count[y.op] = 1
// return loop_exits
// def EnterGradWhileContext(self, op, before):
// """Enter the WhileContext for gradient computation."""
// grad_state = self.GetGradState(op, before)
// if grad_state:
// grad_state.grad_context.Enter()
// def ExitGradWhileContext(self, op, before):
// """Exit the WhileContext for gradient computation."""
// grad_state = self.GetGradState(op, before)
// if grad_state:
// grad_state.grad_context.Exit()
// def AddWhileContext(self, op, between_op_list, between_ops):
// """Add the grad state for the while loop that op belongs to.
// Note that op is an Exit, and this method must be called in
// the control flow context where gradients() is called.
// Note that this method modifies `between_op_list` and `between_ops`.
// """
// forward_ctxt = _GetWhileContext(op)
// grad_state = self._map.get(forward_ctxt)
// if grad_state is None:
// # This is a new while loop so create a grad state for it.
// outer_forward_ctxt = forward_ctxt.outer_context
// if outer_forward_ctxt:
// outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
// outer_grad_state = None
// if outer_forward_ctxt:
// outer_grad_state = self._map.get(outer_forward_ctxt)
// grad_state = GradLoopState(forward_ctxt, outer_grad_state)
// self._map[forward_ctxt] = grad_state
// # We need to include all exits of a loop for backprop.
// for loop_exit in grad_state.forward_loop_exits:
// if loop_exit.op not in between_ops:
// between_ops.add(loop_exit.op)
// between_op_list.append(loop_exit.op)
// def ZerosLikeForExit(self, val):
// """Create zeros_like gradient for a loop exit.
// If the result of a loop variable is not used but is involved in
// computing the result of some needed loop variable, we create a
// zero-valued tensor that is fed as gradient for the Exit node of that
// loop variable. Note that val.op is an Exit, and this method must be
// called in the control flow context where gradients() is called.
// Args:
// val: The output tensor of an Exit op.
// Returns:
// A zero tensor of the same shape of val.
// """
// val_shape = val.get_shape()
// forward_ctxt = val.op._get_control_flow_context()
// outer_forward_ctxt = forward_ctxt.outer_context
// if outer_forward_ctxt:
// outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
// outer_grad_state = None
// if outer_forward_ctxt:
// outer_grad_state = self._map.get(outer_forward_ctxt)
// if outer_grad_state:
// # This is a nested loop.
// if val_shape.is_fully_defined():
// # If the shape is known statically, just create a zero tensor
// # with the right shape in the right context.
// outer_grad_state.grad_context.Enter()
// result = array_ops.zeros(val_shape.dims, val.dtype)
// outer_grad_state.grad_context.Exit()
// else:
// # Only the shape of value is needed for backprop.
// forward_ctxt.outer_context.Enter()
// shape = array_ops.shape_internal(val, optimize=False)
// forward_ctxt.outer_context.Exit()
// # Save the shape to a stack.
// history_shape = outer_grad_state.AddForwardAccumulator(shape)
// # Get the shape back from the stack.
// outer_grad_ctxt = outer_grad_state.grad_context
// outer_grad_ctxt.Enter()
// real_shape = outer_grad_state.AddBackpropAccumulatedValue(
// history_shape, shape)
// result = array_ops.zeros(real_shape, val.dtype)
// outer_grad_ctxt.Exit()
// else:
// # This is not a nested loop.
// if val_shape.is_fully_defined():
// # If the shape is known statically, just create a zero tensor
// # with the right shape.
// result = array_ops.zeros(val_shape.dims, val.dtype)
// else:
// result = array_ops.zeros_like(val, optimize=False)
// return result
// def ZerosLike(self, op, index):
// """Create zeros_like for the specified output of an op.
// If op is in a while loop that is part of gradients(), this method
// must be called in its grad loop context.
// Args:
// op: A tensorflow operation.
// index: the index for a specific output of the op.
// Returns:
// A zero tensor of the same shape of op.outputs[index].
// """
// if util.IsLoopSwitch(op):
// return None
// if op.graph._building_function: # pylint: disable=protected-access
// # The optimization here is tricky to apply to functions
// return array_ops.zeros_like(op.outputs[index])
// dead_branch = util.IsSwitch(op)
// forward_ctxt = _GetWhileContext(op)
// grad_state = self._map.get(forward_ctxt)
// if grad_state is None:
// # op is not in a while loop that is part of gradients().
// return ZerosLikeOutsideLoop(op, index)
// op_ctxt = op._get_control_flow_context()
// val = ops.convert_to_tensor(op.outputs[index], name="tensor")
// shape = val.get_shape()
// if shape.is_fully_defined():
// # If the shape is known statically, just create a zero tensor with
// # the right shape in the grad loop context.
// result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
// if dead_branch:
// # op is a cond switch. Guard the zero tensor with a switch.
// pred = grad_state.history_map.get(op_ctxt.pred.name)
// branch = op_ctxt.branch
// result = _SwitchRefOrTensor(result, pred)[1 - branch]
// else:
// # Unknown shape so keep a history of the shape at runtime.
// if dead_branch:
// # Need to add a special switch to guard the value.
// pred = op_ctxt.pred
// branch = op_ctxt.branch
// op_ctxt.outer_context.Enter()
// val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
// zeros_shape = array_ops.shape_internal(val, optimize=False)
// op_ctxt.outer_context.Exit()
// val.op._set_control_flow_context(op_ctxt)
// zeros_shape.op._set_control_flow_context(op_ctxt)
// else:
// op_ctxt.Enter()
// zeros_shape = array_ops.shape_internal(val, optimize=False)
// op_ctxt.Exit()
// # Add forward accumulator for shape.
// grad_state.grad_context.Exit()
// history_zeros_shape = grad_state.AddForwardAccumulator(
// zeros_shape, dead_branch=dead_branch)
// grad_state.grad_context.Enter()
// # Create a zero tensor with the right shape.
// shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
// zeros_shape, dead_branch)
// result = array_ops.zeros(shape, val.dtype)
// return result
// def PostProcessing(self):
// """Perform postprocessing at the end of gradients().
// We have created the gradient graph at this point. So this function
// can be used to perform any postprocessing on the gradient graph.
// We currently perform the following postprocessing:
// 1. Patch the gradient graph if the output of a loop variable
// doesn't depend on its input.
// """
// for _, grad_state in self._map.items():
// for _, b_merge in grad_state.switch_map.items():
// if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
// # The value of this loop variable at iteration i+1 doesn't
// # depend on its value at iteration i. So use zeros as the
// # gradients for all iterations > 0.
// dtype = b_merge.op.inputs[0].dtype
// shape = b_merge.op.inputs[0].get_shape()
// # pylint: disable=protected-access
// if shape.is_fully_defined():
// grad_state.grad_context.Enter()
// # Create a zeros and use it for iterations > 0.
// grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
// next_grad_val = _NextIteration(grad_val)
// grad_state.grad_context.Exit()
// else:
// # Create a zeros in the outer grad context.
// outer_grad_ctxt = grad_state.grad_context.outer_context
// if outer_grad_ctxt:
// outer_grad_ctxt.Enter()
// enter_grad_op = b_merge.op.inputs[0].op
// enter_grad = enter_grad_op.inputs[0]
// grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
// grad_val = array_ops.zeros(grad_shape)
// if outer_grad_ctxt:
// outer_grad_ctxt.Exit()
// # Use the zeros for iterations > 0.
// grad_state.grad_context.Enter()
// next_grad_val = _NextIteration(grad_val)
// grad_state.grad_context.Exit()
// b_merge.op._update_input(1, next_grad_val)
// # pylint: enable=protected-access
}
}
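
The class body above is still the ported Python kept as comments. As a rough pointer for the eventual port, a minimal C# sketch of GetGradState might look like the following; the _map dictionary and control_flow_util.IsLoopExit are assumptions, not members introduced by this commit.

    // Sketch only: mapping a forward while loop to its GradLoopState.
    private Dictionary<WhileContext, GradLoopState> _map = new Dictionary<WhileContext, GradLoopState>();

    public GradLoopState GetGradState(Operation op, bool before)
    {
        WhileContext forward_ctxt;
        if (before && control_flow_util.IsLoopExit(op))   // assumed helper
            // For an Exit seen before entering its context, the loop lives one context up.
            forward_ctxt = op._get_control_flow_context()?.outer_context?.GetWhileContext();
        else
            forward_ctxt = op._get_control_flow_context()?.GetWhileContext();

        if (forward_ctxt != null && _map.ContainsKey(forward_ctxt))
            return _map[forward_ctxt];
        return null;
    }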

+ 398  - 0    src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs

@@ -0,0 +1,398 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Operations.ControlFlows
{
public class GradLoopState
{
//class GradLoopState(object):
// """The state used for constructing the gradient graph for a while loop.
// We create a GradLoopState for each while loop in forward and its
// corresponding while loop in backprop. This gives us access to both
// the forward and the backprop WhileContexts.
// During the construction of gradient graph, any time when we detect
// a forward value that is needed for backprop, we create a history
// accumulator and add it to `history_map`. Any time when we backprop
// a loop switch op (in _SwitchGrad), we add the grad merge op in
// `switch_map`.
// """
// def __init__(self, forward_ctxt, outer_grad_state):
// # The grad loop state for the outer while loop.
// self._outer_grad_state = None
// # The while loop context for forward.
// self._forward_context = None
// # The loop counter added by AddForwardLoopCounter. It is the value
// # of the loop counter for the next iteration.
// self._forward_index = None
// # A sync op for forward.
// self._forward_sync = None
// # The while loop context for backprop.
private WhileContext _grad_context = null;
public WhileContext grad_context => _grad_context;
// # The loop counter added by AddBackpropLoopCounter. It is the value
// # of the loop counter for the current iteration.
// self._grad_index = None
// # A sync op for backprop.
// self._grad_sync = None
// # Information needed by backprop.
private Hashtable _history_map = new Hashtable();
public Hashtable history_map => _history_map;
private Hashtable _switch_map = new Hashtable();
public Hashtable switch_map => _switch_map;
// self._unused_exits = []
// self._deferred_exits = []
// self._forward_loop_exits = list(forward_ctxt.loop_exits)
// self._pending_exits_count = len(forward_ctxt.loop_exits)
// self._outer_grad_state = outer_grad_state
// if outer_grad_state:
// outer_forward_ctxt = outer_grad_state.forward_context
// else:
// if not hasattr(forward_ctxt, "outer_context"):
// raise ValueError("Failed to call gradients on a while loop without"
// "properly serializing graph via MetaGraphDef")
// outer_forward_ctxt = forward_ctxt.outer_context
// # Add the forward loop counter.
// with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
// if outer_forward_ctxt:
// outer_forward_ctxt.Enter()
// cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
// if outer_forward_ctxt:
// outer_forward_ctxt.Exit()
// self._forward_context = forward_ctxt
// self._forward_index = forward_index
// # Add the backprop WhileContext, and the backprop loop counter.
// if outer_grad_state:
// # This is a nested loop. Remember the iteration counts for each
// # execution of this inner loop.
// outer_forward_ctxt.AddName(cnt.name)
// history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
// outer_grad_ctxt = outer_grad_state.grad_context
// outer_grad_ctxt.Enter()
// self._grad_context = WhileContext(
// maximum_iterations=forward_ctxt.maximum_iterations,
// parallel_iterations=forward_ctxt.parallel_iterations,
// back_prop=forward_ctxt.back_prop,
// swap_memory=forward_ctxt.swap_memory,
// name=forward_ctxt.name,
// grad_state=self)
// real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
// self._grad_index = self._grad_context.AddBackpropLoopCounter(
// real_cnt, outer_grad_state)
// outer_grad_ctxt.Exit()
// else:
// if outer_forward_ctxt:
// outer_forward_ctxt.Enter()
// self._grad_context = WhileContext(
// maximum_iterations=forward_ctxt.maximum_iterations,
// parallel_iterations=forward_ctxt.parallel_iterations,
// back_prop=forward_ctxt.back_prop,
// swap_memory=forward_ctxt.swap_memory,
// name=forward_ctxt.name,
// grad_state=self)
// self._grad_index = self._grad_context.AddBackpropLoopCounter(
// cnt, outer_grad_state)
// if outer_forward_ctxt:
// outer_forward_ctxt.Exit()
// @property
// def outer_grad_state(self):
// """The grad loop state for outer loop."""
// return self._outer_grad_state
// @property
// def forward_context(self):
// """The while loop context for forward."""
// return self._forward_context
// @property
// def forward_index(self):
// """The loop index of forward loop."""
// return self._forward_index
// @property
// def forward_sync(self):
// """A control trigger node for synchronization in the forward loop.
// One main use is to keep the push ops of a stack executed in the
// iteration order.
// """
// if self._forward_sync is None:
// with ops.control_dependencies(None):
// self._forward_sync = control_trigger(name="f_sync")
// self._forward_sync._set_control_flow_context(self._forward_context)
// self._forward_index.op._add_control_input(self._forward_sync)
// return self._forward_sync
// @property
// def grad_context(self):
// """The corresponding WhileContext for gradient."""
// return self._grad_context
// @property
// def grad_index(self):
// """The loop index of backprop loop."""
// return self._grad_index
// @property
// def grad_sync(self):
// """A control trigger node for synchronization in the grad loop.
// One main use is to keep the pop ops of a stack executed in the
// iteration order.
// """
// if self._grad_sync is None:
// with ops.control_dependencies(None):
// self._grad_sync = control_trigger(name="b_sync")
// self._grad_sync._set_control_flow_context(self._grad_context)
// self._grad_index.op._add_control_input(self._grad_sync)
// if self._grad_context.outer_context:
// self._grad_context.outer_context.AddInnerOp(self._grad_sync)
// return self._grad_sync
// @property
// def history_map(self):
// """The map that records all the tensors needed for backprop."""
// return self._history_map
// @property
// def switch_map(self):
// """The map that records all the Switch ops for the while loop."""
// return self._switch_map
// @property
// def unused_exits(self):
// """The list of "unused" exits."""
// return self._unused_exits
// @property
// def deferred_exits(self):
// """The list of "deferred" exits."""
// return self._deferred_exits
// @property
// def forward_loop_exits(self):
// """The list of exits of the forward loop."""
// return self._forward_loop_exits
// @property
// def pending_exits_count(self):
// """The number of exits we expect to see but haven't."""
// return self._pending_exits_count
// @pending_exits_count.setter
// def pending_exits_count(self, cnt):
// """Set the pending count to cnt."""
// self._pending_exits_count = cnt
/// <summary>
/// Add an accumulator for each forward tensor that is needed in backprop.
///
/// This is added to the forward loop at the first time when a tensor
/// in the forward loop is used by backprop gradient computation loop.
/// We create an accumulator that accumulates the value of tensor at each
/// iteration. Called in the control flow context where gradients() is called.
///
/// The pseudocode is:
/// ```
/// acc = stack();
/// while (_pivot) {
/// acc = stack_push(acc, value);
/// }
/// ```
///
/// We make sure that the stack push op in one iteration is executed before
/// next iteration. This is achieved by adding a control edge from
/// `forward_index.op.inputs[0].op` to the push op, and another control
/// edge from the push op to either `forward_index.op` or `forward_sync`.
/// </summary>
/// <param name="value"> The source tensor in forward that is to be accumulated.</param>
/// <param name="dead_branch"> True iff the tensor is on a dead branch of a cond.</param>
/// <returns>The stack that contains the accumulated history of the tensor.</returns>
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
{
throw new NotImplementedException("AddForwardAccumulator");
// # curr_ctxt is the context that tf.gradients was called in.
// with self._forward_index.graph.as_default():
// curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
// with ops.control_dependencies(None):
// if curr_ctxt:
// curr_ctxt.Enter()
// with ops.colocate_with(value):
// # We only need to pass maximum_iterations to the stack if
// # we're inside an XLA context.
// if not util.IsInXLAContext(value.op):
// max_size = constant_op.constant(-1, dtypes.int32)
// else:
// max_size = GetMaxSizeFromNestedMaximumIterations(
// value, self.forward_context)
// acc = gen_data_flow_ops.stack_v2(
// max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
// if curr_ctxt:
// curr_ctxt.Exit()
// # Make acc available in the forward context.
// enter_acc = self.forward_context.AddValue(acc)
// # Add the stack_push op in the context of value.op.
// swap_enabled = self.forward_context.swap_memory
// value_ctxt = util.GetOutputContext(value.op)
// if value_ctxt == self.forward_context:
// # value is not nested in the forward context.
// self.forward_context.Enter()
// push = gen_data_flow_ops.stack_push_v2(
// enter_acc, value, swap_memory=swap_enabled)
// self.forward_context.Exit()
// # Protect stack push and order it before forward_index.
// self.forward_index.op._add_control_input(push.op)
// else:
// # value is in a cond context within the forward context.
// if not isinstance(value_ctxt, CondContext):
// raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
// if dead_branch:
// # The special case for creating a zero tensor for a dead
// # branch of a switch. See ControlFlowState.ZerosLike().
// value_ctxt.outer_context.Enter()
// push = gen_data_flow_ops.stack_push_v2(
// enter_acc, value, swap_memory=swap_enabled)
// value_ctxt.outer_context.Exit()
// push.op._set_control_flow_context(value_ctxt)
// else:
// value_ctxt.Enter()
// push = gen_data_flow_ops.stack_push_v2(
// enter_acc, value, swap_memory=swap_enabled)
// value_ctxt.Exit()
// # Protect stack push and order it before forward_sync.
// self.forward_sync._add_control_input(push.op)
// # Order stack push after the successor of forward_index
// add_op = self.forward_index.op.inputs[0].op
// push.op._add_control_input(add_op)
// return acc
}
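
As a companion to the pseudocode in the summary above, a heavily simplified C# sketch of the straightforward case of AddForwardAccumulator (the value is produced directly in the forward while context, no dead cond branch). The gen_data_flow_ops stack wrappers and the forward_context/forward_index properties are all assumptions; none of them are part of this commit.

    // Sketch only: simplest case of AddForwardAccumulator.
    public Tensor AddForwardAccumulatorSketch(Tensor value)
    {
        // An unbounded stack that will hold one value per forward iteration.
        var max_size = constant_op.constant(-1);                               // int32 scalar
        var acc = gen_data_flow_ops.stack_v2(max_size, value.dtype, "f_acc");  // assumed op wrapper

        // Make the stack handle available inside the forward loop.
        var enter_acc = forward_context.AddValue(acc);                         // forward_context assumed

        forward_context.Enter();
        var push = gen_data_flow_ops.stack_push_v2(enter_acc, value);          // assumed op wrapper
        forward_context.Exit();

        // Order the push before the next iteration's loop counter update.
        forward_index.op._add_control_input(push.op);                          // forward_index assumed
        return acc;
    }
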
// """Add the getter for an accumulated value in the grad context.
//
// This is added to the backprop loop. Called in the grad context to
// get the value of an accumulated value. The stack pop op must be guarded
// by the pred of the controlling cond.
//
// Args:
// history_value: The history (a stack) of a value.
// value: The value that is pushed onto the stack.
// dead_branch: True iff the tensor is on a dead branch of a cond.
//
// Returns:
// The current value (the top of the stack).
// """
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
{
throw new NotImplementedException();
// history_ctxt = history_value.op._get_control_flow_context()
// # Find the cond context that controls history_value if any.
// cond_ctxt = None
// value_ctxt = value.op._get_control_flow_context()
// while value_ctxt and value_ctxt != history_ctxt:
// if isinstance(value_ctxt, CondContext):
// cond_ctxt = value_ctxt
// break
// value_ctxt = value_ctxt.outer_context
// with ops.control_dependencies(None):
// self.grad_context.Enter()
// if cond_ctxt:
// # Guard stack pop with a switch if it is controlled by a cond.
// grad_state = self
// pred = None
// while pred is None and grad_state:
// pred = grad_state.history_map.get(cond_ctxt.pred.name)
// grad_state = grad_state.outer_grad_state
// if pred is None:
// pred = cond_ctxt.pred
// branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
// history_value = _SwitchRefOrTensor(history_value, pred)[branch]
// pop = gen_data_flow_ops.stack_pop_v2(history_value,
// value.dtype.base_dtype)
// pop.set_shape(value.get_shape())
// self.grad_context.Exit()
// parallel_iterations = self.grad_context.parallel_iterations
// if parallel_iterations > 1:
// # All pops are ordered after pivot_for_body and before grad_sync.
// self.grad_sync._add_control_input(pop.op)
// return pop
}
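
And the matching sketch for the unguarded case of AddBackpropAccumulatedValue: pop the accumulated value inside the gradient loop. stack_pop_v2 is again an assumed wrapper, not part of this commit.

    // Sketch only: unguarded stack pop in the gradient loop.
    public Tensor AddBackpropAccumulatedValueSketch(Tensor history_value, Tensor value)
    {
        grad_context.Enter();
        var pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype);  // assumed op wrapper
        grad_context.Exit();
        // With parallel_iterations > 1 the pop would additionally be ordered before grad_sync.
        return pop;
    }
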
// def GetRealValue(self, value):
// """Get the real value of `value`.
// If backprop "uses" a value produced by forward inference, an accumulator
// is added in the forward loop to accumulate its values. We use the
// accumulated value. This method must be called in the grad loop context.
// `value` must be in forward and needed for backprop.
// Args:
// value: A tensor to be captured.
// Returns:
// The same tensor obtained from the saved history.
// """
// assert value.op.type not in ["Variable", "VariableV2"]
// real_value = self._history_map.get(value.name)
// if real_value is None:
// cur_value = value
// cur_grad_state = self
// while True:
// enter_op = util.GetLoopConstantEnter(cur_value)
// if enter_op:
// # Special case: cur_value comes from a constant Enter node.
// cur_value = enter_op.inputs[0]
// cur_grad_state = cur_grad_state.outer_grad_state
// if cur_grad_state is None:
// # We are now outside all nested loops for this gradient(),
// # so `value` is a loop invariant and there is no need to
// # save the history of value. Just make cur_value to enter
// # the right control flow context.
// real_value = self._grad_context.AddValue(cur_value)
// break
// elif constant_op.is_constant(cur_value):
// # If the value to be forwarded is a constant, clone the constant in
// # the gradient loop rather than using a stack.
// # TODO(phawkins): consider hoisting the constant out of the loop
// # instead.
// real_value = constant_op.constant(
// tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
// break
// else:
// # Record the history of this value in forward_ctxt.
// self._grad_context.Exit()
// history_value = cur_grad_state.AddForwardAccumulator(cur_value)
// self._grad_context.Enter()
// break
// if real_value is None:
// # Add the stack pop op in the grad context.
// real_value = cur_grad_state.AddBackpropAccumulatedValue(
// history_value, cur_value)
// if cur_grad_state != self:
// real_value = self._grad_context.AddValue(real_value)
// self._history_map[value.name] = real_value
// return real_value
}
}

+ 11  - 9    src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs

@@ -4,13 +4,15 @@ using System.Text;

namespace Tensorflow
{
public interface IControlFlowContext
{
void AddOp(Operation op);
IControlFlowContext outer_context { get; }
HashSet<string> values { get; }
Tensor AddValue(Tensor val);
void AddInnerOp(Operation resultOp);
object to_proto();
}
// henon: this was too much trouble. there is no value just cost to use an interface here.
//public interface IControlFlowContext
//{
// void AddOp(Operation op);
// IControlFlowContext outer_context { get; }
// HashSet<string> values { get; }
// Tensor pivot { get; }
// Tensor AddValue(Tensor val);
// void AddInnerOp(Operation resultOp);
// object to_proto();
//}
}

+ 15  - 0    src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

@@ -1,11 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.ControlFlows;

namespace Tensorflow.Operations
{
public class WhileContext : ControlFlowContext
{
private bool _back_prop=true;

private GradLoopState _grad_state =null;

public override WhileContext GetWhileContext()
{
return this;
}


public override GradLoopState grad_state => _grad_state;

public override bool back_prop => _back_prop;

public static WhileContext from_proto(object proto)
{
throw new NotImplementedException();


+ 3  - 3    src/TensorFlowNET.Core/Operations/Operation.Control.cs

@@ -7,7 +7,7 @@ namespace Tensorflow
{
public partial class Operation
{
private IControlFlowContext _control_flow_context;
private ControlFlowContext _control_flow_context;
/// <summary>
/// Add this op to its control flow context.
@@ -39,12 +39,12 @@ namespace Tensorflow
_add_control_input(op);
}

public void _set_control_flow_context(IControlFlowContext ctx)
public void _set_control_flow_context(ControlFlowContext ctx)
{
_control_flow_context = ctx;
}

public IControlFlowContext _get_control_flow_context()
public ControlFlowContext _get_control_flow_context()
{
return _control_flow_context;
}


+ 12  - 7    src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Operations.ControlFlows;
using util = Tensorflow.control_flow_util;

namespace Tensorflow
@@ -93,9 +94,9 @@ namespace Tensorflow
/// <param name="between_op_list"></param>
/// <param name="between_ops"></param>
/// <param name="colocate_gradients_with_ops"></param>
public static object MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
{
object loop_state = null;
ControlFlowState loop_state = null;

foreach (var op in between_op_list)
{
@@ -103,7 +104,7 @@ namespace Tensorflow
{
if(loop_state == null)
{
// loop_state = ControlFlowState();
loop_state = new ControlFlowState();
}
}
}
@@ -207,7 +208,7 @@ namespace Tensorflow
/// `(output_false, output_true)`: If `pred` is true, data will be forwarded to
/// `output_true`, otherwise it goes to `output_false`.
/// </returns>
public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
{
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
@@ -298,7 +299,9 @@ namespace Tensorflow
*/

// Add the Switch to the graph.
var (p_2, p_1) = @switch(pred, pred);
var switch_result= @switch(pred, pred);
var p_2=switch_result[0];
var p_1 = switch_result[1];
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
var pivot_2 = array_ops.identity(p_2, name: "switch_f");
pred = array_ops.identity(pred, name: "pred_id");
@@ -379,7 +382,9 @@ namespace Tensorflow
return with(ops.name_scope(name, "cond", new { pred }), delegate
{
// Add the Switch to the graph.
var (p_2, p_1) = @switch(pred, pred);
var switch_result = @switch(pred, pred);
var p_2 = switch_result[0];
var p_1 = switch_result[1];
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
var pivot_2 = array_ops.identity(p_2, name: "switch_f");
pred = array_ops.identity(pred, name: "pred_id");
@@ -460,7 +465,7 @@ namespace Tensorflow
/// <param name="pred"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
public static (Tensor, Tensor) @switch(Tensor data,
public static Tensor[] @switch(Tensor data,
Tensor pred,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null)


+ 1  - 1    src/TensorFlowNET.Core/Operations/control_flow_util.py.cs

@@ -30,7 +30,7 @@ namespace Tensorflow
/// <summary>
/// Return the control flow context for the output of an op.
/// </summary>
public static IControlFlowContext GetOutputContext(Operation op)
public static ControlFlowContext GetOutputContext(Operation op)
{
var ctxt = op._get_control_flow_context();
// Exit nodes usually have a control flow context, except in the case where the


+ 2  - 2    src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs

@@ -33,14 +33,14 @@ namespace Tensorflow
/// output_false: A `Tensor`. Has the same type as `data`.
/// output_true: A `Tensor`. Has the same type as `data`.
/// </returns>
public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
public static Tensor[] @switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
var _inputs_flat = _op.inputs;
var _attrs = ("T", _op.get_attr("T"));
// TODO: missing original code
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name);
return (_op.outputs[0], _op.outputs[1]);
return new []{_op.outputs[0], _op.outputs[1]};
}
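
With @switch (and likewise _SwitchRefOrTensor) now returning a Tensor[] instead of a (Tensor, Tensor) tuple, call sites index the result rather than deconstructing it: element 0 carries output_false and element 1 output_true, matching the contract documented above. A small illustration; data and pred are hypothetical tensors.

    // Illustration only: consuming the Tensor[] returned by the reworked switch helpers.
    var results = gen_control_flow_ops.@switch(data, pred);
    var output_false = results[0];   // forwarded when pred is false
    var output_true  = results[1];   // forwarded when pred is true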

public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)

