@@ -1,12 +1,79 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow.Gradients | |||
{ | |||
/// <summary> | |||
/// Gradients for operators defined in control_flow_ops.py.cs | |||
/// </summary> | |||
public class control_flow_grad | |||
{ | |||
/// <summary> | |||
/// Gradients for a Switch op are calculated using a Merge op.
///
/// If the switch is a loop switch, it will be visited twice. We create
/// the merge on the first visit, and update the other input of the merge
/// on the second visit. A next_iteration is also added on the second visit.
/// </summary> | |||
/// <returns>The gradients with respect to the inputs of the Switch op.</returns>
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) | |||
{ | |||
throw new NotImplementedException("_SwitchGrad"); | |||
//graph = ops.get_default_graph() | |||
//# pylint: disable=protected-access | |||
//op_ctxt = op._get_control_flow_context() | |||
//grad_ctxt = graph._get_control_flow_context() | |||
//# pylint: enable=protected-access | |||
//if isinstance(op_ctxt, WhileContext): | |||
// merge_grad = grad_ctxt.grad_state.switch_map.get(op) | |||
// if merge_grad is not None: | |||
// # This is the second time this Switch is visited. It comes from | |||
// # the non-exit branch of the Switch, so update the second input | |||
// # to the Merge. | |||
// # TODO(yuanbyu): Perform shape inference with this new input. | |||
// if grad[1] is not None: | |||
// # pylint: disable=protected-access | |||
// control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], | |||
// enforce_shape_invariant=False) | |||
// # pylint: enable=protected-access | |||
// return None, None | |||
// elif grad[0] is not None: | |||
// # This is the first time this Switch is visited. It comes from | |||
// # the Exit branch, which is grad[0]. grad[1] is empty at this point. | |||
// # Use grad[0] for both inputs to merge for now, but update the second | |||
// # input of merge when we see this Switch the second time. | |||
// merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] | |||
// grad_ctxt.grad_state.switch_map[op] = merge_grad | |||
// return merge_grad, None | |||
// else: | |||
// # This is the first time this Switch is visited. It comes from the | |||
// # Identity branch. Such a Switch has `None` gradient for the Exit branch, | |||
// # meaning the output is not differentiable. | |||
// return None, None | |||
//elif isinstance(op_ctxt, CondContext): | |||
// zero_grad = grad[1 - op_ctxt.branch] | |||
// # At this point, we have created zero_grad guarded by the right switch. | |||
// # Unfortunately, we may still get None here for not trainable data types. | |||
// if zero_grad is None: | |||
// # For resource variables we get None always on the other branch, so bypass | |||
// # this. | |||
// if op.inputs[0].dtype == dtypes.resource: | |||
// return merge( | |||
// [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None | |||
// return None, None | |||
// return merge(grad, name="cond_grad")[0], None | |||
//else: | |||
// false_grad = switch(grad[0], op.inputs[1])[0] | |||
// true_grad = switch(grad[1], op.inputs[1])[1] | |||
// return merge([false_grad, true_grad])[0], None | |||
} | |||
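// Editorial sketch (not the ported implementation) of the CondContext branch of the
// commented Python above. It assumes CondContext exposes a public `branch` index next to
// `pred`, and that the tuple-returning merge shown later in this diff is reachable as
// gen_control_flow_ops.merge; both are assumptions, not confirmed APIs of this code base.
private Tensor[] _SwitchGradCondSketch(Operation op, Tensor[] grads)
{
    var op_ctxt = op._get_control_flow_context();
    if (op_ctxt is CondContext cond_ctxt)
    {
        // grads[branch] carries the gradient of the taken branch; the other entry may be
        // null for non-trainable dtypes, in which case nothing flows back to the Switch.
        var zero_grad = grads[1 - cond_ctxt.branch];
        if (zero_grad == null)
            return new Tensor[] { null, null };
        var (merged, _) = gen_control_flow_ops.merge(grads, name: "cond_grad");
        // Gradient w.r.t. the data input; the predicate input gets no gradient.
        return new Tensor[] { merged, null };
    }
    throw new NotImplementedException("while-loop case, see the commented Python above");
}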
/// <summary> | |||
/// Gradients for a Merge op are calculated using a Switch op. | |||
/// </summary> | |||
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | |||
{ | |||
var grad = grads[0]; | |||
@@ -14,10 +81,164 @@ namespace Tensorflow.Gradients | |||
var input_op = op.inputs[0].op; | |||
var graph = ops.get_default_graph(); | |||
var op_ctxt = control_flow_util.GetOutputContext(input_op); | |||
var grad_ctxt = graph._get_control_flow_context(); | |||
switch (op_ctxt) | |||
{ | |||
case WhileContext cwhile: | |||
{ | |||
return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); | |||
} | |||
case CondContext ccond: | |||
{ | |||
var pred = ccond.pred; | |||
if (grad_ctxt != null && grad_ctxt.grad_state != null) | |||
{ | |||
//# This Merge node is part of a cond within a loop. | |||
//# The backprop needs to have the value of this predicate for every | |||
//# iteration. So we must have its values accumulated in the forward, and | |||
//# use the accumulated values as the predicate for this backprop switch. | |||
var grad_state = grad_ctxt.grad_state; | |||
var real_pred = grad_state.history_map[pred.name] as Tensor; | |||
if (real_pred == null) | |||
{ | |||
//# Remember the value of pred for every iteration. | |||
grad_ctxt = grad_state.grad_context; | |||
grad_ctxt.Exit(); | |||
var history_pred = grad_state.AddForwardAccumulator(pred); | |||
grad_ctxt.Enter(); | |||
//# Add the stack pop op. If pred.op is in a (outer) CondContext, | |||
//# the stack pop will be guarded with a switch. | |||
real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); | |||
grad_state.history_map[pred.name] = real_pred; | |||
} | |||
pred = real_pred; | |||
} | |||
var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||
return results; | |||
} | |||
default: | |||
{ | |||
var num_inputs = op.inputs.Length; | |||
var cond = new Tensor[num_inputs]; | |||
for (int i = 0; i < num_inputs; i++) | |||
cond[i] = math_ops.equal(op.outputs[1], i); | |||
var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); | |||
return result; | |||
} | |||
} | |||
}
public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | |||
{ | |||
return _MergeGrad(op, grads); | |||
} | |||
/// <summary> | |||
/// Gradients for an exit op are calculated using an Enter op. | |||
/// </summary> | |||
public Tensor[] _ExitGrad(Operation op, Tensor[] grads) | |||
{ | |||
throw new NotImplementedException("_ExitGrad"); | |||
// graph = ops.get_default_graph() | |||
//# pylint: disable=protected-access | |||
// op_ctxt = op._get_control_flow_context() | |||
// grad_ctxt = graph._get_control_flow_context() | |||
// # pylint: enable=protected-access | |||
// if not grad_ctxt.back_prop: | |||
// # The flag `back_prop` is set by users to suppress gradient | |||
// # computation for this loop. If the attribute `back_prop` is false, | |||
// # no gradient computation. | |||
// return None | |||
// if op_ctxt.grad_state: | |||
// raise TypeError("Second-order gradient for while loops not supported.") | |||
// if isinstance(grad, ops.Tensor) : | |||
// grad_ctxt.AddName(grad.name) | |||
// else: | |||
// if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): | |||
// raise TypeError("Type %s not supported" % type(grad)) | |||
// grad_ctxt.AddName(grad.values.name) | |||
// grad_ctxt.AddName(grad.indices.name) | |||
// dense_shape = grad.dense_shape | |||
// if dense_shape is not None: | |||
// grad_ctxt.AddName(dense_shape.name) | |||
// grad_ctxt.Enter() | |||
// # pylint: disable=protected-access | |||
// result = control_flow_ops._Enter( | |||
// grad, grad_ctxt.name, is_constant=False, | |||
// parallel_iterations=grad_ctxt.parallel_iterations, | |||
// name="b_exit") | |||
// # pylint: enable=protected-access | |||
// grad_ctxt.loop_enters.append(result) | |||
// grad_ctxt.Exit() | |||
// return result | |||
} | |||
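// Editorial sketch (illustration only) of the plain, non-nested case of the commented Python
// above. control_flow_ops._Enter with these parameters is an assumed counterpart of the
// Python helper and is not a confirmed member of this port.
private Tensor _ExitGradSketch(Operation op, Tensor grad)
{
    var graph = ops.get_default_graph();
    var grad_ctxt = graph._get_control_flow_context();
    if (!grad_ctxt.back_prop)
        return null;   // the user suppressed gradient computation for this loop
    grad_ctxt.Enter();
    // The backprop mirror of a forward Exit is an Enter into the gradient while-loop.
    var result = control_flow_ops._Enter(grad, grad_ctxt.name, is_constant: false, name: "b_exit");
    grad_ctxt.Exit();
    return result;
}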
/// <summary> | |||
/// A forward next_iteration is translated into a backprop identity. | |||
/// | |||
/// Note that the backprop next_iteration is added in switch grad. | |||
/// </summary> | |||
public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) | |||
{ | |||
return (_, grad); | |||
} | |||
public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) | |||
{ | |||
return (_, grad); | |||
} | |||
/// <summary> | |||
/// Gradients for an Enter are calculated using an Exit op. | |||
/// | |||
/// For loop variables, grad is the gradient so just add an exit. | |||
/// For loop invariants, we need to add an accumulator loop. | |||
/// </summary> | |||
public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad) | |||
{ | |||
throw new NotImplementedException("_EnterGrad"); | |||
// graph = ops.get_default_graph() | |||
//# pylint: disable=protected-access | |||
// grad_ctxt = graph._get_control_flow_context() | |||
// # pylint: enable=protected-access | |||
// if not grad_ctxt.back_prop: | |||
// # Skip gradient computation, if the attribute `back_prop` is false. | |||
// return grad | |||
// if grad_ctxt.grad_state is None: | |||
// # Pass the gradient through if we are not in a gradient while context. | |||
// return grad | |||
// if op.get_attr("is_constant"): | |||
// # Add a gradient accumulator for each loop invariant. | |||
// if isinstance(grad, ops.Tensor) : | |||
// result = grad_ctxt.AddBackpropAccumulator(op, grad) | |||
// elif isinstance(grad, ops.IndexedSlices) : | |||
// result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) | |||
// else: | |||
// # TODO(yuanbyu, lukasr): Add support for SparseTensor. | |||
// raise TypeError("Type %s not supported" % type(grad)) | |||
// else: | |||
// result = exit(grad) | |||
// grad_ctxt.loop_exits.append(result) | |||
// grad_ctxt.ExitResult([result]) | |||
// return result | |||
} | |||
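// Editorial sketch (illustration only) of the commented Python above. AddBackpropAccumulator
// and a C# exit() helper are assumed counterparts of the Python calls and may not exist
// under these names in this port.
private Tensor _EnterGradSketch(Operation op, Tensor grad)
{
    var graph = ops.get_default_graph();
    var grad_ctxt = graph._get_control_flow_context();
    if (!grad_ctxt.back_prop || grad_ctxt.grad_state == null)
        return grad;   // backprop disabled, or we are not in a gradient while context
    if ((bool)op.get_attr("is_constant"))
        // Loop invariant: accumulate the gradient over all iterations.
        return ((WhileContext)grad_ctxt).AddBackpropAccumulator(op, grad);
    // Loop variable: the gradient leaves the gradient loop through an Exit.
    return control_flow_ops.exit(grad);
}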
public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad) | |||
{ | |||
return _EnterGrad(op, grad); | |||
} | |||
/// <summary> | |||
/// Stop backprop for the predicate of a while loop. | |||
/// </summary> | |||
public object _LoopCondGrad(object _) | |||
{ | |||
return null; | |||
} | |||
} | |||
} |
@@ -3,13 +3,14 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
// Current control flow context. It could be either CondContext or WhileContext | |||
public ControlFlowContext _control_flow_context; | |||
// represents the nested with(...) statements | |||
public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); | |||
@@ -97,7 +98,7 @@ namespace Tensorflow | |||
/// Returns the current control flow context. | |||
/// </summary> | |||
/// <returns>A context object.</returns> | |||
public ControlFlowContext _get_control_flow_context() | |||
{ | |||
return _control_flow_context; | |||
} | |||
@@ -106,7 +107,7 @@ namespace Tensorflow | |||
/// Sets the current control flow context. | |||
/// </summary> | |||
/// <param name="ctx">a context object.</param> | |||
public void _set_control_flow_context(ControlFlowContext ctx) | |||
{ | |||
_control_flow_context = ctx; | |||
} | |||
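// Hypothetical convenience wrapper (not part of the port) showing how the two accessors
// above are typically used together: install a context, build ops under it, then restore
// the previous one.
public void _with_control_flow_context(ControlFlowContext ctx, Action body)
{
    var saved = _get_control_flow_context();
    _set_control_flow_context(ctx);
    try
    {
        body();
    }
    finally
    {
        _set_control_flow_context(saved);
    }
}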
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
@@ -15,7 +16,7 @@ namespace Tensorflow | |||
private List<ITensorOrOperation> _seen_nodes; | |||
private List<_ControlDependenciesController> _old_stack; | |||
private bool _new_stack; | |||
private ControlFlowContext _old_control_flow_context; | |||
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations.ControlFlows; | |||
namespace Tensorflow.Operations | |||
{ | |||
@@ -107,8 +108,8 @@ namespace Tensorflow.Operations | |||
with(ops.control_dependencies(null), ctrl => | |||
{ | |||
var results = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||
result = results[_branch]; | |||
if (_outer_context != null) | |||
_outer_context.AddInnerOp(result.op); | |||
}); | |||
@@ -118,7 +119,7 @@ namespace Tensorflow.Operations | |||
// Mark Switch output as seen by this context and any outer contexts, | |||
// just like what we do for normal op outputs in _AddOpInternal() below. | |||
ControlFlowContext ctxt = this; | |||
while (ctxt != null) | |||
{ | |||
ctxt.values.Add(result.name); | |||
@@ -223,8 +224,8 @@ namespace Tensorflow.Operations | |||
_values.Add(real_val.name); | |||
_external_values[real_val.name] = real_val; | |||
} | |||
var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||
real_val = results[_branch]; | |||
_external_values[val.name] = real_val; | |||
} | |||
else | |||
@@ -238,8 +239,8 @@ namespace Tensorflow.Operations | |||
return real_val; | |||
} | |||
protected override void _AddOpInternal(Operation op) | |||
{ | |||
if (op.inputs.Length == 0) | |||
{ | |||
//If we're in a while loop, remove any control inputs from outside the | |||
@@ -282,11 +283,11 @@ namespace Tensorflow.Operations | |||
// TODO: implement below code dependencies | |||
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | |||
// op._add_control_input(_pivot.op); | |||
} | |||
// Mark op's outputs as seen by this context and any outer contexts. | |||
var output_names = op.outputs.Select(x => x.name).ToArray(); | |||
ControlFlowContext ctxt = this; | |||
while (ctxt != null) | |||
{ | |||
foreach (var name in output_names) | |||
@@ -298,9 +299,31 @@ namespace Tensorflow.Operations | |||
op.graph.prevent_fetching(op); | |||
if (_outer_context != null) | |||
_outer_context.AddInnerOp(op); | |||
} | |||
public override GradLoopState grad_state | |||
{ | |||
get | |||
{ | |||
var whc = GetWhileContext(); | |||
if (whc != null) | |||
return whc.grad_state; | |||
return null; | |||
} | |||
} | |||
public override bool back_prop | |||
{ | |||
get | |||
{ | |||
var whc = GetWhileContext(); | |||
if (whc != null) | |||
return whc.back_prop; | |||
return false; | |||
} | |||
} | |||
public CondContextDef to_proto(string export_scope) | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations.ControlFlows; | |||
namespace Tensorflow.Operations | |||
{ | |||
@@ -22,21 +23,25 @@ namespace Tensorflow.Operations | |||
/// 4. A ControlFlowContext has _context_stack. | |||
/// Pushed and popped by ctxt.Enter() and ctxt.Exit() | |||
/// </summary> | |||
public abstract class ControlFlowContext : Python, IPython | |||
{ | |||
/// <summary> | |||
/// The predicate tensor in this branch | |||
/// </summary> | |||
protected Tensor _pivot; | |||
public Tensor pivot | |||
{ | |||
get => _pivot; | |||
} | |||
protected Stack<ControlFlowContext> _context_stack; | |||
protected ControlFlowContext _outer_context; | |||
protected Dictionary<string, ITensorOrOperation> _external_values; | |||
public ControlFlowContext() | |||
{ | |||
_context_stack = new Stack<ControlFlowContext>(); | |||
} | |||
public string name { get => _name; } | |||
@@ -111,8 +116,13 @@ namespace Tensorflow.Operations | |||
_AddOpInternal(op); | |||
} | |||
public ControlFlowContext outer_context { get { return _outer_context; } } | |||
public HashSet<string> values => _values; | |||
public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method"); | |||
public virtual bool back_prop => throw new NotImplementedException("abstract method"); | |||
public virtual Tensor AddValue(Tensor val) | |||
{ | |||
// to be overridden | |||
@@ -147,7 +157,7 @@ namespace Tensorflow.Operations | |||
/// <summary> | |||
/// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. | |||
/// </summary> | |||
public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||
{ | |||
while (ctxt != maybe_containing_ctxt) | |||
{ | |||
@@ -164,6 +174,16 @@ namespace Tensorflow.Operations | |||
var internal_control_inputs = op.control_inputs; | |||
} | |||
/// <summary> | |||
/// Return the while context containing this context | |||
/// </summary> | |||
public virtual WhileContext GetWhileContext() | |||
{ | |||
if (_outer_context != null) | |||
return _outer_context.GetWhileContext(); | |||
return null; | |||
} | |||
public object to_proto() | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -173,5 +193,6 @@ namespace Tensorflow.Operations | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,277 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.ControlFlows | |||
{ | |||
/// <summary> | |||
/// Maintain the mapping from the loops to their grad states. | |||
/// </summary> | |||
public class ControlFlowState | |||
{ | |||
//class ControlFlowState(object): | |||
// """Maintain the mapping from the loops to their grad states.""" | |||
// def __init__(self): | |||
// self._map = {} # maps forward loop context to GradLoopState | |||
// def GetGradState(self, op, before): | |||
// """Return the grad state for this op if it's in a forward loop context.""" | |||
// if before and util.IsLoopExit(op): | |||
// forward_ctxt = op._get_control_flow_context() | |||
// forward_ctxt = forward_ctxt.outer_context | |||
// if forward_ctxt: | |||
// forward_ctxt = forward_ctxt.GetWhileContext() | |||
// else: | |||
// forward_ctxt = _GetWhileContext(op) | |||
// if forward_ctxt: | |||
// return self._map.get(forward_ctxt) | |||
// return None | |||
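// A possible C# shape (a sketch, not the final port) for the state described above: map each
// forward WhileContext to its GradLoopState and look the state up from an op. The
// control_flow_util.IsLoopExit helper is assumed to mirror the Python util.IsLoopExit.
private Dictionary<WhileContext, GradLoopState> _map = new Dictionary<WhileContext, GradLoopState>();

public GradLoopState GetGradState(Operation op, bool before)
{
    WhileContext forward_ctxt = null;
    if (before && control_flow_util.IsLoopExit(op))
    {
        // For an Exit, the enclosing while loop is one context up.
        var ctxt = op._get_control_flow_context();
        ctxt = ctxt?.outer_context;
        if (ctxt != null)
            forward_ctxt = ctxt.GetWhileContext();
    }
    else
    {
        var ctxt = op._get_control_flow_context();
        forward_ctxt = ctxt?.GetWhileContext();
    }
    if (forward_ctxt != null && _map.ContainsKey(forward_ctxt))
        return _map[forward_ctxt];
    return null;
}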
// def ProcessUnusedLoopExits(self, pending_count, to_ops_set): | |||
// """Process all the "unused" loop exits. | |||
// The "unused" exits of the loops are added to `unused_exits`. An exit is | |||
// unused if its pending_count is 0. If there is an exit with real gradient, | |||
// all these deferred exits will enter the backprop loop with zero gradient. | |||
// Otherwise, they will enter the backprop loop with None. As an example, | |||
// people often write: | |||
// ```python | |||
// v1, _ = tf.while_loop(p, b, [x1, x2]) | |||
// result = gradients(v1, x1) | |||
// ``` | |||
// The exit node for x2 is not included by the betweenness analysis. But we | |||
// need to backprop x2 if x2 is involved in computing v1. | |||
// Args: | |||
// pending_count: The number of backprop inputs for every op. | |||
// to_ops_set: The set of ops for ys in gradients(ys, xs) | |||
// Returns: | |||
// The set of unused loop exits that we know at this point we need | |||
// to backprop. | |||
// """ | |||
// loop_exits = [] | |||
// for grad_state in self._map.values(): | |||
// for y in grad_state.forward_loop_exits: | |||
// if pending_count[y.op] == 0: | |||
// grad_state.pending_exits_count -= 1 | |||
// if y.op not in to_ops_set: | |||
// grad_state.unused_exits.append(y) | |||
// if grad_state.pending_exits_count == 0: | |||
// loop_exits.extend(grad_state.unused_exits) | |||
// # Need to include Enters in backprop for higher-order gradients. | |||
// for y in grad_state.forward_context.loop_enters: | |||
// if pending_count[y.op] == 0: | |||
// pending_count[y.op] = 1 | |||
// return loop_exits | |||
// def EnterGradWhileContext(self, op, before): | |||
// """Enter the WhileContext for gradient computation.""" | |||
// grad_state = self.GetGradState(op, before) | |||
// if grad_state: | |||
// grad_state.grad_context.Enter() | |||
// def ExitGradWhileContext(self, op, before): | |||
// """Exit the WhileContext for gradient computation.""" | |||
// grad_state = self.GetGradState(op, before) | |||
// if grad_state: | |||
// grad_state.grad_context.Exit() | |||
// def AddWhileContext(self, op, between_op_list, between_ops): | |||
// """Add the grad state for the while loop that op belongs to. | |||
// Note that op is an Exit, and this method must be called in | |||
// the control flow context where gradients() is called. | |||
// Note that this method modifies `between_op_list` and `between_ops`. | |||
// """ | |||
// forward_ctxt = _GetWhileContext(op) | |||
// grad_state = self._map.get(forward_ctxt) | |||
// if grad_state is None: | |||
// # This is a new while loop so create a grad state for it. | |||
// outer_forward_ctxt = forward_ctxt.outer_context | |||
// if outer_forward_ctxt: | |||
// outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||
// outer_grad_state = None | |||
// if outer_forward_ctxt: | |||
// outer_grad_state = self._map.get(outer_forward_ctxt) | |||
// grad_state = GradLoopState(forward_ctxt, outer_grad_state) | |||
// self._map[forward_ctxt] = grad_state | |||
// # We need to include all exits of a loop for backprop. | |||
// for loop_exit in grad_state.forward_loop_exits: | |||
// if loop_exit.op not in between_ops: | |||
// between_ops.add(loop_exit.op) | |||
// between_op_list.append(loop_exit.op) | |||
// def ZerosLikeForExit(self, val): | |||
// """Create zeros_like gradient for a loop exit. | |||
// If the result of a loop variable is not used but is involved in | |||
// computing the result of some needed loop variable, we create a | |||
// zero-valued tensor that is fed as gradient for the Exit node of that | |||
// loop variable. Note that val.op is an Exit, and this method must be | |||
// called in the control flow context where gradients() is called. | |||
// Args: | |||
// val: The output tensor of an Exit op. | |||
// Returns: | |||
// A zero tensor of the same shape of val. | |||
// """ | |||
// val_shape = val.get_shape() | |||
// forward_ctxt = val.op._get_control_flow_context() | |||
// outer_forward_ctxt = forward_ctxt.outer_context | |||
// if outer_forward_ctxt: | |||
// outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||
// outer_grad_state = None | |||
// if outer_forward_ctxt: | |||
// outer_grad_state = self._map.get(outer_forward_ctxt) | |||
// if outer_grad_state: | |||
// # This is a nested loop. | |||
// if val_shape.is_fully_defined(): | |||
// # If the shape is known statically, just create a zero tensor | |||
// # with the right shape in the right context. | |||
// outer_grad_state.grad_context.Enter() | |||
// result = array_ops.zeros(val_shape.dims, val.dtype) | |||
// outer_grad_state.grad_context.Exit() | |||
// else: | |||
// # Only the shape of value is needed for backprop. | |||
// forward_ctxt.outer_context.Enter() | |||
// shape = array_ops.shape_internal(val, optimize=False) | |||
// forward_ctxt.outer_context.Exit() | |||
// # Save the shape to a stack. | |||
// history_shape = outer_grad_state.AddForwardAccumulator(shape) | |||
// # Get the shape back from the stack. | |||
// outer_grad_ctxt = outer_grad_state.grad_context | |||
// outer_grad_ctxt.Enter() | |||
// real_shape = outer_grad_state.AddBackpropAccumulatedValue( | |||
// history_shape, shape) | |||
// result = array_ops.zeros(real_shape, val.dtype) | |||
// outer_grad_ctxt.Exit() | |||
// else: | |||
// # This is not a nested loop. | |||
// if val_shape.is_fully_defined(): | |||
// # If the shape is known statically, just create a zero tensor | |||
// # with the right shape. | |||
// result = array_ops.zeros(val_shape.dims, val.dtype) | |||
// else: | |||
// result = array_ops.zeros_like(val, optimize=False) | |||
// return result | |||
// def ZerosLike(self, op, index): | |||
// """Create zeros_like for the specified output of an op. | |||
// If op is in a while loop that is part of gradients(), this method | |||
// must be called in its grad loop context. | |||
// Args: | |||
// op: A tensorflow operation. | |||
// index: the index for a specific output of the op. | |||
// Returns: | |||
// A zero tensor of the same shape of op.outputs[index]. | |||
// """ | |||
// if util.IsLoopSwitch(op): | |||
// return None | |||
// if op.graph._building_function: # pylint: disable=protected-access | |||
// # The optimization here is tricky to apply to functions | |||
// return array_ops.zeros_like(op.outputs[index]) | |||
// dead_branch = util.IsSwitch(op) | |||
// forward_ctxt = _GetWhileContext(op) | |||
// grad_state = self._map.get(forward_ctxt) | |||
// if grad_state is None: | |||
// # op is not in a while loop that is part of gradients(). | |||
// return ZerosLikeOutsideLoop(op, index) | |||
// op_ctxt = op._get_control_flow_context() | |||
// val = ops.convert_to_tensor(op.outputs[index], name="tensor") | |||
// shape = val.get_shape() | |||
// if shape.is_fully_defined(): | |||
// # If the shape is known statically, just create a zero tensor with | |||
// # the right shape in the grad loop context. | |||
// result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) | |||
// if dead_branch: | |||
// # op is a cond switch. Guard the zero tensor with a switch. | |||
// pred = grad_state.history_map.get(op_ctxt.pred.name) | |||
// branch = op_ctxt.branch | |||
// result = _SwitchRefOrTensor(result, pred)[1 - branch] | |||
// else: | |||
// # Unknown shape so keep a history of the shape at runtime. | |||
// if dead_branch: | |||
// # Need to add a special switch to guard the value. | |||
// pred = op_ctxt.pred | |||
// branch = op_ctxt.branch | |||
// op_ctxt.outer_context.Enter() | |||
// val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] | |||
// zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
// op_ctxt.outer_context.Exit() | |||
// val.op._set_control_flow_context(op_ctxt) | |||
// zeros_shape.op._set_control_flow_context(op_ctxt) | |||
// else: | |||
// op_ctxt.Enter() | |||
// zeros_shape = array_ops.shape_internal(val, optimize=False) | |||
// op_ctxt.Exit() | |||
// # Add forward accumulator for shape. | |||
// grad_state.grad_context.Exit() | |||
// history_zeros_shape = grad_state.AddForwardAccumulator( | |||
// zeros_shape, dead_branch=dead_branch) | |||
// grad_state.grad_context.Enter() | |||
// # Create a zero tensor with the right shape. | |||
// shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, | |||
// zeros_shape, dead_branch) | |||
// result = array_ops.zeros(shape, val.dtype) | |||
// return result | |||
// def PostProcessing(self): | |||
// """Perform postprocessing at the end of gradients(). | |||
// We have created the gradient graph at this point. So this function | |||
// can be used to perform any postprocessing on the gradient graph. | |||
// We currently perform the following postprocessing: | |||
// 1. Patch the gradient graph if the output of a loop variable | |||
// doesn't depend on its input. | |||
// """ | |||
// for _, grad_state in self._map.items(): | |||
// for _, b_merge in grad_state.switch_map.items(): | |||
// if b_merge.op.inputs[0] == b_merge.op.inputs[1]: | |||
// # The value of this loop variable at iteration i+1 doesn't | |||
// # depend on its value at iteration i. So use zeros as the | |||
// # gradients for all iterations > 0. | |||
// dtype = b_merge.op.inputs[0].dtype | |||
// shape = b_merge.op.inputs[0].get_shape() | |||
// # pylint: disable=protected-access | |||
// if shape.is_fully_defined(): | |||
// grad_state.grad_context.Enter() | |||
// # Create a zeros and use it for iterations > 0. | |||
// grad_val = constant_op.constant(0, dtype=dtype, shape=shape) | |||
// next_grad_val = _NextIteration(grad_val) | |||
// grad_state.grad_context.Exit() | |||
// else: | |||
// # Create a zeros in the outer grad context. | |||
// outer_grad_ctxt = grad_state.grad_context.outer_context | |||
// if outer_grad_ctxt: | |||
// outer_grad_ctxt.Enter() | |||
// enter_grad_op = b_merge.op.inputs[0].op | |||
// enter_grad = enter_grad_op.inputs[0] | |||
// grad_shape = array_ops.shape_internal(enter_grad, optimize=False) | |||
// grad_val = array_ops.zeros(grad_shape) | |||
// if outer_grad_ctxt: | |||
// outer_grad_ctxt.Exit() | |||
// # Use the zeros for iterations > 0. | |||
// grad_state.grad_context.Enter() | |||
// next_grad_val = _NextIteration(grad_val) | |||
// grad_state.grad_context.Exit() | |||
// b_merge.op._update_input(1, next_grad_val) | |||
// # pylint: enable=protected-access | |||
} | |||
} |
@@ -0,0 +1,398 @@ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.ControlFlows | |||
{ | |||
public class GradLoopState | |||
{ | |||
//class GradLoopState(object): | |||
// """The state used for constructing the gradient graph for a while loop. | |||
// We create a GradLoopState for each while loop in forward and its | |||
// corresponding while loop in backprop. This gives us access to both | |||
// the forward and the backprop WhileContexts. | |||
// During the construction of gradient graph, any time when we detect | |||
// a forward value that is needed for backprop, we create a history | |||
// accumulator and add it to `history_map`. Any time when we backprop | |||
// a loop switch op (in _SwitchGrad), we add the grad merge op in | |||
// `switch_map`. | |||
// """ | |||
// def __init__(self, forward_ctxt, outer_grad_state): | |||
// # The grad loop state for the outer while loop. | |||
// self._outer_grad_state = None | |||
// # The while loop context for forward. | |||
// self._forward_context = None | |||
// # The loop counter added by AddForwardLoopCounter. It is the value | |||
// # of the loop counter for the next iteration. | |||
// self._forward_index = None | |||
// # A sync op for forward. | |||
// self._forward_sync = None | |||
// # The while loop context for backprop. | |||
private WhileContext _grad_context = null; | |||
public WhileContext grad_context => _grad_context; | |||
// # The loop counter added by AddBackpropLoopCounter. It is the value | |||
// # of the loop counter for the current iteration. | |||
// self._grad_index = None | |||
// # A sync op for backprop. | |||
// self._grad_sync = None | |||
// # Information needed by backprop. | |||
private Hashtable _history_map = new Hashtable(); | |||
public Hashtable history_map => _history_map; | |||
private Hashtable _switch_map = new Hashtable(); | |||
public Hashtable switch_map => _switch_map; | |||
// self._unused_exits = [] | |||
// self._deferred_exits = [] | |||
// self._forward_loop_exits = list(forward_ctxt.loop_exits) | |||
// self._pending_exits_count = len(forward_ctxt.loop_exits) | |||
// self._outer_grad_state = outer_grad_state | |||
// if outer_grad_state: | |||
// outer_forward_ctxt = outer_grad_state.forward_context | |||
// else: | |||
// if not hasattr(forward_ctxt, "outer_context"): | |||
// raise ValueError("Failed to call gradients on a while loop without" | |||
// "properly serializing graph via MetaGraphDef") | |||
// outer_forward_ctxt = forward_ctxt.outer_context | |||
// # Add the forward loop counter. | |||
// with forward_ctxt._graph.as_default(): # pylint: disable=protected-access | |||
// if outer_forward_ctxt: | |||
// outer_forward_ctxt.Enter() | |||
// cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) | |||
// if outer_forward_ctxt: | |||
// outer_forward_ctxt.Exit() | |||
// self._forward_context = forward_ctxt | |||
// self._forward_index = forward_index | |||
// # Add the backprop WhileContext, and the backprop loop counter. | |||
// if outer_grad_state: | |||
// # This is a nested loop. Remember the iteration counts for each | |||
// # execution of this inner loop. | |||
// outer_forward_ctxt.AddName(cnt.name) | |||
// history_cnt = outer_grad_state.AddForwardAccumulator(cnt) | |||
// outer_grad_ctxt = outer_grad_state.grad_context | |||
// outer_grad_ctxt.Enter() | |||
// self._grad_context = WhileContext( | |||
// maximum_iterations=forward_ctxt.maximum_iterations, | |||
// parallel_iterations=forward_ctxt.parallel_iterations, | |||
// back_prop=forward_ctxt.back_prop, | |||
// swap_memory=forward_ctxt.swap_memory, | |||
// name=forward_ctxt.name, | |||
// grad_state=self) | |||
// real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) | |||
// self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
// real_cnt, outer_grad_state) | |||
// outer_grad_ctxt.Exit() | |||
// else: | |||
// if outer_forward_ctxt: | |||
// outer_forward_ctxt.Enter() | |||
// self._grad_context = WhileContext( | |||
// maximum_iterations=forward_ctxt.maximum_iterations, | |||
// parallel_iterations=forward_ctxt.parallel_iterations, | |||
// back_prop=forward_ctxt.back_prop, | |||
// swap_memory=forward_ctxt.swap_memory, | |||
// name=forward_ctxt.name, | |||
// grad_state=self) | |||
// self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||
// cnt, outer_grad_state) | |||
// if outer_forward_ctxt: | |||
// outer_forward_ctxt.Exit() | |||
// @property | |||
// def outer_grad_state(self): | |||
// """The grad loop state for outer loop.""" | |||
// return self._outer_grad_state | |||
// @property | |||
// def forward_context(self): | |||
// """The while loop context for forward.""" | |||
// return self._forward_context | |||
// @property | |||
// def forward_index(self): | |||
// """The loop index of forward loop.""" | |||
// return self._forward_index | |||
// @property | |||
// def forward_sync(self): | |||
// """A control trigger node for synchronization in the forward loop. | |||
// One main use is to keep the push ops of a stack executed in the | |||
// iteration order. | |||
// """ | |||
// if self._forward_sync is None: | |||
// with ops.control_dependencies(None): | |||
// self._forward_sync = control_trigger(name="f_sync") | |||
// self._forward_sync._set_control_flow_context(self._forward_context) | |||
// self._forward_index.op._add_control_input(self._forward_sync) | |||
// return self._forward_sync | |||
// @property | |||
// def grad_context(self): | |||
// """The corresponding WhileContext for gradient.""" | |||
// return self._grad_context | |||
// @property | |||
// def grad_index(self): | |||
// """The loop index of backprop loop.""" | |||
// return self._grad_index | |||
// @property | |||
// def grad_sync(self): | |||
// """A control trigger node for synchronization in the grad loop. | |||
// One main use is to keep the pop ops of a stack executed in the | |||
// iteration order. | |||
// """ | |||
// if self._grad_sync is None: | |||
// with ops.control_dependencies(None): | |||
// self._grad_sync = control_trigger(name="b_sync") | |||
// self._grad_sync._set_control_flow_context(self._grad_context) | |||
// self._grad_index.op._add_control_input(self._grad_sync) | |||
// if self._grad_context.outer_context: | |||
// self._grad_context.outer_context.AddInnerOp(self._grad_sync) | |||
// return self._grad_sync | |||
// @property | |||
// def history_map(self): | |||
// """The map that records all the tensors needed for backprop.""" | |||
// return self._history_map | |||
// @property | |||
// def switch_map(self): | |||
// """The map that records all the Switch ops for the while loop.""" | |||
// return self._switch_map | |||
// @property | |||
// def unused_exits(self): | |||
// """The list of "unused" exits.""" | |||
// return self._unused_exits | |||
// @property | |||
// def deferred_exits(self): | |||
// """The list of "deferred" exits.""" | |||
// return self._deferred_exits | |||
// @property | |||
// def forward_loop_exits(self): | |||
// """The list of exits of the forward loop.""" | |||
// return self._forward_loop_exits | |||
// @property | |||
// def pending_exits_count(self): | |||
// """The number of exits we expect to see but haven't.""" | |||
// return self._pending_exits_count | |||
// @pending_exits_count.setter | |||
// def pending_exits_count(self, cnt): | |||
// """Set the pending count to cnt.""" | |||
// self._pending_exits_count = cnt | |||
/// <summary> | |||
/// Add an accumulator for each forward tensor that is needed in backprop. | |||
/// | |||
/// This is added to the forward loop at the first time when a tensor | |||
/// in the forward loop is used by backprop gradient computation loop. | |||
/// We create an accumulator that accumulates the value of tensor at each | |||
/// iteration. Called in the control flow context where gradients() is called. | |||
/// | |||
/// The pseudocode is: | |||
/// ``` | |||
/// acc = stack(); | |||
/// while (_pivot) { | |||
/// acc = stack_push(acc, value); | |||
/// } | |||
/// ``` | |||
/// | |||
/// We make sure that the stack push op in one iteration is executed before | |||
/// next iteration. This is achieved by adding a control edge from | |||
/// `forward_index.op.inputs[0].op` to the push op, and another control | |||
/// edge from the push op to either `forward_index.op` or `forward_sync`. | |||
/// </summary> | |||
/// <param name="value"> The source tensor in forward that is to be accumulated.</param> | |||
/// <param name="dead_branch"> True iff the tensor is on a dead branch of a cond.</param> | |||
/// <returns>The stack that contains the accumulated history of the tensor.</returns> | |||
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) | |||
{ | |||
throw new NotImplementedException("AddForwardAccumulator"); | |||
// # curr_ctxt is the context that tf.gradients was called in. | |||
// with self._forward_index.graph.as_default(): | |||
// curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access | |||
// with ops.control_dependencies(None): | |||
// if curr_ctxt: | |||
// curr_ctxt.Enter() | |||
// with ops.colocate_with(value): | |||
// # We only need to pass maximum_iterations to the stack if | |||
// # we're inside an XLA context. | |||
// if not util.IsInXLAContext(value.op): | |||
// max_size = constant_op.constant(-1, dtypes.int32) | |||
// else: | |||
// max_size = GetMaxSizeFromNestedMaximumIterations( | |||
// value, self.forward_context) | |||
// acc = gen_data_flow_ops.stack_v2( | |||
// max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") | |||
// if curr_ctxt: | |||
// curr_ctxt.Exit() | |||
// # Make acc available in the forward context. | |||
// enter_acc = self.forward_context.AddValue(acc) | |||
// # Add the stack_push op in the context of value.op. | |||
// swap_enabled = self.forward_context.swap_memory | |||
// value_ctxt = util.GetOutputContext(value.op) | |||
// if value_ctxt == self.forward_context: | |||
// # value is not nested in the forward context. | |||
// self.forward_context.Enter() | |||
// push = gen_data_flow_ops.stack_push_v2( | |||
// enter_acc, value, swap_memory=swap_enabled) | |||
// self.forward_context.Exit() | |||
// # Protect stack push and order it before forward_index. | |||
// self.forward_index.op._add_control_input(push.op) | |||
// else: | |||
// # value is in a cond context within the forward context. | |||
// if not isinstance(value_ctxt, CondContext): | |||
// raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) | |||
// if dead_branch: | |||
// # The special case for creating a zero tensor for a dead | |||
// # branch of a switch. See ControlFlowState.ZerosLike(). | |||
// value_ctxt.outer_context.Enter() | |||
// push = gen_data_flow_ops.stack_push_v2( | |||
// enter_acc, value, swap_memory=swap_enabled) | |||
// value_ctxt.outer_context.Exit() | |||
// push.op._set_control_flow_context(value_ctxt) | |||
// else: | |||
// value_ctxt.Enter() | |||
// push = gen_data_flow_ops.stack_push_v2( | |||
// enter_acc, value, swap_memory=swap_enabled) | |||
// value_ctxt.Exit() | |||
// # Protect stack push and order it before forward_sync. | |||
// self.forward_sync._add_control_input(push.op) | |||
// # Order stack push after the successor of forward_index | |||
// add_op = self.forward_index.op.inputs[0].op | |||
// push.op._add_control_input(add_op) | |||
// return acc | |||
} | |||
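// Editorial sketch (illustration only) of the pseudocode in the summary above, for the simple
// case where `value` lives directly in the forward context. forward_context, forward_index,
// swap_memory and the gen_data_flow_ops stack wrappers used here are assumed counterparts of
// the Python members and are not yet part of this port.
public Tensor AddForwardAccumulatorSketch(Tensor value)
{
    // Create the stack that will hold one entry per forward iteration.
    var acc = gen_data_flow_ops.stack_v2(constant_op.constant(-1), value.dtype, name: "f_acc");
    // Make the stack handle available inside the forward loop, then push each iteration's value.
    var enter_acc = forward_context.AddValue(acc);
    forward_context.Enter();
    var push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: forward_context.swap_memory);
    forward_context.Exit();
    // Order the push before the next iteration of the forward loop counter.
    forward_index.op._add_control_input(push.op);
    return acc;
}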
// """Add the getter for an accumulated value in the grad context. | |||
// | |||
// This is added to the backprop loop. Called in the grad context to | |||
// get the value of an accumulated value. The stack pop op must be guarded | |||
// by the pred of the controlling cond. | |||
// | |||
// Args: | |||
// history_value: The history (a stack) of a value. | |||
// value: The value that is pushed onto the stack. | |||
// dead_branch: True iff the tensor is on a dead branch of a cond. | |||
// | |||
// Returns: | |||
// The current value (the top of the stack). | |||
// """ | |||
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) | |||
{ | |||
throw new NotImplementedException(); | |||
// history_ctxt = history_value.op._get_control_flow_context() | |||
// # Find the cond context that controls history_value if any. | |||
// cond_ctxt = None | |||
// value_ctxt = value.op._get_control_flow_context() | |||
// while value_ctxt and value_ctxt != history_ctxt: | |||
// if isinstance(value_ctxt, CondContext): | |||
// cond_ctxt = value_ctxt | |||
// break | |||
// value_ctxt = value_ctxt.outer_context | |||
// with ops.control_dependencies(None): | |||
// self.grad_context.Enter() | |||
// if cond_ctxt: | |||
// # Guard stack pop with a switch if it is controlled by a cond. | |||
// grad_state = self | |||
// pred = None | |||
// while pred is None and grad_state: | |||
// pred = grad_state.history_map.get(cond_ctxt.pred.name) | |||
// grad_state = grad_state.outer_grad_state | |||
// if pred is None: | |||
// pred = cond_ctxt.pred | |||
// branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch | |||
// history_value = _SwitchRefOrTensor(history_value, pred)[branch] | |||
// pop = gen_data_flow_ops.stack_pop_v2(history_value, | |||
// value.dtype.base_dtype) | |||
// pop.set_shape(value.get_shape()) | |||
// self.grad_context.Exit() | |||
// parallel_iterations = self.grad_context.parallel_iterations | |||
// if parallel_iterations > 1: | |||
// # All pops are ordered after pivot_for_body and before grad_sync. | |||
// self.grad_sync._add_control_input(pop.op) | |||
// return pop | |||
} | |||
// def GetRealValue(self, value): | |||
// """Get the real value of `value`. | |||
// If backprop "uses" a value produced by forward inference, an accumulator | |||
// is added in the forward loop to accumulate its values. We use the | |||
// accumulated value. This method must be called in the grad loop context. | |||
// `value` must be in forward and needed for backprop. | |||
// Args: | |||
// value: A tensor to be captured. | |||
// Returns: | |||
// The same tensor obtained from the saved history. | |||
// """ | |||
// assert value.op.type not in ["Variable", "VariableV2"] | |||
// real_value = self._history_map.get(value.name) | |||
// if real_value is None: | |||
// cur_value = value | |||
// cur_grad_state = self | |||
// while True: | |||
// enter_op = util.GetLoopConstantEnter(cur_value) | |||
// if enter_op: | |||
// # Special case: cur_value comes from a constant Enter node. | |||
// cur_value = enter_op.inputs[0] | |||
// cur_grad_state = cur_grad_state.outer_grad_state | |||
// if cur_grad_state is None: | |||
// # We are now outside all nested loops for this gradient(), | |||
// # so `value` is a loop invariant and there is no need to | |||
// # save the history of value. Just make cur_value to enter | |||
// # the right control flow context. | |||
// real_value = self._grad_context.AddValue(cur_value) | |||
// break | |||
// elif constant_op.is_constant(cur_value): | |||
// # If the value to be forwarded is a constant, clone the constant in | |||
// # the gradient loop rather than using a stack. | |||
// # TODO(phawkins): consider hoisting the constant out of the loop | |||
// # instead. | |||
// real_value = constant_op.constant( | |||
// tensor_util.constant_value(cur_value), dtype=cur_value.dtype) | |||
// break | |||
// else: | |||
// # Record the history of this value in forward_ctxt. | |||
// self._grad_context.Exit() | |||
// history_value = cur_grad_state.AddForwardAccumulator(cur_value) | |||
// self._grad_context.Enter() | |||
// break | |||
// if real_value is None: | |||
// # Add the stack pop op in the grad context. | |||
// real_value = cur_grad_state.AddBackpropAccumulatedValue( | |||
// history_value, cur_value) | |||
// if cur_grad_state != self: | |||
// real_value = self._grad_context.AddValue(real_value) | |||
// self._history_map[value.name] = real_value | |||
// return real_value | |||
} | |||
} |
@@ -4,13 +4,15 @@ using System.Text; | |||
namespace Tensorflow | |||
{ | |||
// henon: this was too much trouble. There is no value, only cost, in using an interface here.
//public interface IControlFlowContext | |||
//{ | |||
// void AddOp(Operation op); | |||
// IControlFlowContext outer_context { get; } | |||
// HashSet<string> values { get; } | |||
// Tensor pivot { get; } | |||
// Tensor AddValue(Tensor val); | |||
// void AddInnerOp(Operation resultOp); | |||
// object to_proto(); | |||
//} | |||
} |
@@ -1,11 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations.ControlFlows; | |||
namespace Tensorflow.Operations | |||
{ | |||
public class WhileContext : ControlFlowContext | |||
{ | |||
private bool _back_prop = true;
private GradLoopState _grad_state = null;
public override WhileContext GetWhileContext() | |||
{ | |||
return this; | |||
} | |||
public override GradLoopState grad_state => _grad_state; | |||
public override bool back_prop => _back_prop; | |||
public static WhileContext from_proto(object proto) | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -7,7 +7,7 @@ namespace Tensorflow | |||
{ | |||
public partial class Operation | |||
{ | |||
private ControlFlowContext _control_flow_context; | |||
/// <summary> | |||
/// Add this op to its control flow context. | |||
@@ -39,12 +39,12 @@ namespace Tensorflow | |||
_add_control_input(op); | |||
} | |||
public void _set_control_flow_context(ControlFlowContext ctx) | |||
{ | |||
_control_flow_context = ctx; | |||
} | |||
public ControlFlowContext _get_control_flow_context() | |||
{ | |||
return _control_flow_context; | |||
} | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Operations.ControlFlows; | |||
using util = Tensorflow.control_flow_util; | |||
namespace Tensorflow | |||
@@ -93,9 +94,9 @@ namespace Tensorflow | |||
/// <param name="between_op_list"></param> | |||
/// <param name="between_ops"></param> | |||
/// <param name="colocate_gradients_with_ops"></param> | |||
public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||
{ | |||
ControlFlowState loop_state = null; | |||
foreach (var op in between_op_list) | |||
{ | |||
@@ -103,7 +104,7 @@ namespace Tensorflow | |||
{ | |||
if(loop_state == null) | |||
{ | |||
loop_state = new ControlFlowState(); | |||
} | |||
} | |||
} | |||
@@ -207,7 +208,7 @@ namespace Tensorflow | |||
/// `(output_false, output_true)`: If `pred` is true, data will be forwarded to | |||
/// `output_true`, otherwise it goes to `output_false`. | |||
/// </returns> | |||
public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||
{ | |||
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | |||
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | |||
@@ -298,7 +299,9 @@ namespace Tensorflow | |||
*/ | |||
// Add the Switch to the graph. | |||
var switch_result= @switch(pred, pred); | |||
var p_2=switch_result[0]; | |||
var p_1 = switch_result[1]; | |||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
pred = array_ops.identity(pred, name: "pred_id"); | |||
@@ -379,7 +382,9 @@ namespace Tensorflow | |||
return with(ops.name_scope(name, "cond", new { pred }), delegate | |||
{ | |||
// Add the Switch to the graph. | |||
var switch_result = @switch(pred, pred); | |||
var p_2 = switch_result[0]; | |||
var p_1 = switch_result[1]; | |||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
pred = array_ops.identity(pred, name: "pred_id"); | |||
@@ -460,7 +465,7 @@ namespace Tensorflow | |||
/// <param name="pred"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="name"></param> | |||
public static Tensor[] @switch(Tensor data, | |||
Tensor pred, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null) | |||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Return the control flow context for the output of an op. | |||
/// </summary> | |||
public static ControlFlowContext GetOutputContext(Operation op) | |||
{ | |||
var ctxt = op._get_control_flow_context(); | |||
// Exit nodes usually have a control flow context, except in the case where the | |||
@@ -33,14 +33,14 @@ namespace Tensorflow | |||
/// output_false: A `Tensor`. Has the same type as `data`. | |||
/// output_true: A `Tensor`. Has the same type as `data`. | |||
/// </returns> | |||
public static Tensor[] @switch(Tensor data, Tensor pred, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | |||
var _inputs_flat = _op.inputs; | |||
var _attrs = ("T", _op.get_attr("T")); | |||
// TODO: missing original code | |||
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); | |||
return new []{_op.outputs[0], _op.outputs[1]}; | |||
} | |||
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | |||