From 692faaa9bd878f47a4540f1c3f27a7c6fa8c7b9a Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Wed, 17 Apr 2019 10:44:50 +0200 Subject: [PATCH] gradients/flow control: added much missing structure --- .../Gradients/control_flow_grad.py.cs | 229 +++++++++- .../Graphs/Graph.Control.cs | 7 +- .../Graphs/_ControlDependenciesController.cs | 3 +- .../Operations/ControlFlows/CondContext.cs | 51 ++- .../ControlFlows/ControlFlowContext.cs | 33 +- .../ControlFlows/ControlFlowState.cs | 277 ++++++++++++ .../Operations/ControlFlows/GradLoopState.cs | 398 ++++++++++++++++++ .../ControlFlows/IControlFlowContext.cs | 20 +- .../Operations/ControlFlows/WhileContext.cs | 15 + .../Operations/Operation.Control.cs | 6 +- .../Operations/control_flow_ops.py.cs | 19 +- .../Operations/control_flow_util.py.cs | 2 +- .../Operations/gen_control_flow_ops.py.cs | 4 +- 13 files changed, 1014 insertions(+), 50 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs index afc87d45..de61e52b 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs @@ -1,12 +1,79 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Operations; namespace Tensorflow.Gradients { + /// + /// Gradients for operators defined in control_flow_ops.py.cs + /// public class control_flow_grad { + /// + /// Gradients for a Switch op is calculated using a Merge op. + /// + /// If the switch is a loop switch, it will be visited twice. We create + /// the merge on the first visit, and update the other input of the merge + /// on the second visit. A next_iteration is also added on second visit. + /// + /// + public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) + { + throw new NotImplementedException("_SwitchGrad"); + //graph = ops.get_default_graph() + //# pylint: disable=protected-access + //op_ctxt = op._get_control_flow_context() + //grad_ctxt = graph._get_control_flow_context() + //# pylint: enable=protected-access + //if isinstance(op_ctxt, WhileContext): + // merge_grad = grad_ctxt.grad_state.switch_map.get(op) + // if merge_grad is not None: + // # This is the second time this Switch is visited. It comes from + // # the non-exit branch of the Switch, so update the second input + // # to the Merge. + // # TODO(yuanbyu): Perform shape inference with this new input. + // if grad[1] is not None: + // # pylint: disable=protected-access + // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], + // enforce_shape_invariant=False) + // # pylint: enable=protected-access + // return None, None + // elif grad[0] is not None: + // # This is the first time this Switch is visited. It comes from + // # the Exit branch, which is grad[0]. grad[1] is empty at this point. + // # Use grad[0] for both inputs to merge for now, but update the second + // # input of merge when we see this Switch the second time. + // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] + // grad_ctxt.grad_state.switch_map[op] = merge_grad + // return merge_grad, None + // else: + // # This is the first time this Switch is visited. It comes from the + // # Identity branch. Such a Switch has `None` gradient for the Exit branch, + // # meaning the output is not differentiable. 
+ // return None, None + //elif isinstance(op_ctxt, CondContext): + // zero_grad = grad[1 - op_ctxt.branch] + // # At this point, we have created zero_grad guarded by the right switch. + // # Unfortunately, we may still get None here for not trainable data types. + // if zero_grad is None: + // # For resource variables we get None always on the other branch, so bypass + // # this. + // if op.inputs[0].dtype == dtypes.resource: + // return merge( + // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None + // return None, None + // return merge(grad, name="cond_grad")[0], None + //else: + // false_grad = switch(grad[0], op.inputs[1])[0] + // true_grad = switch(grad[1], op.inputs[1])[1] + // return merge([false_grad, true_grad])[0], None + } + + /// + /// Gradients for a Merge op are calculated using a Switch op. + /// public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) { var grad = grads[0]; @@ -14,10 +81,164 @@ namespace Tensorflow.Gradients var input_op = op.inputs[0].op; var graph = ops.get_default_graph(); var op_ctxt = control_flow_util.GetOutputContext(input_op); - var pred = (op_ctxt as CondContext).pred; + var grad_ctxt = graph._get_control_flow_context(); + switch (op_ctxt) + { + case WhileContext cwhile: + { + return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); + } + case CondContext ccond: + { + var pred = ccond.pred; + if (grad_ctxt != null && grad_ctxt.grad_state != null) + { + //# This Merge node is part of a cond within a loop. + //# The backprop needs to have the value of this predicate for every + //# iteration. So we must have its values accumulated in the forward, and + //# use the accumulated values as the predicate for this backprop switch. + var grad_state = grad_ctxt.grad_state; + var real_pred = grad_state.history_map[pred.name] as Tensor; + if (real_pred == null) + { + //# Remember the value of pred for every iteration. + grad_ctxt = grad_state.grad_context; + grad_ctxt.Exit(); + var history_pred = grad_state.AddForwardAccumulator(pred); + grad_ctxt.Enter(); + + //# Add the stack pop op. If pred.op is in a (outer) CondContext, + //# the stack pop will be guarded with a switch. + real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); + grad_state.history_map[pred.name] = real_pred; + } + pred = real_pred; + } + var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); + return results; + } + default: + { + var num_inputs = op.inputs.Length; + var cond = new Tensor[num_inputs]; + for (int i = 0; i < num_inputs; i++) + cond[i] = math_ops.equal(op.outputs[1], i); + var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); + return result; + } + } - var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); - return new Tensor[] { results.Item1, results.Item2 }; } - } + + public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) + { + return _MergeGrad(op, grads); + } + + /// + /// Gradients for an exit op are calculated using an Enter op. + /// + public Tensor[] _ExitGrad(Operation op, Tensor[] grads) + { + throw new NotImplementedException("_ExitGrad"); + // graph = ops.get_default_graph() + //# pylint: disable=protected-access + // op_ctxt = op._get_control_flow_context() + // grad_ctxt = graph._get_control_flow_context() + // # pylint: enable=protected-access + // if not grad_ctxt.back_prop: + // # The flag `back_prop` is set by users to suppress gradient + // # computation for this loop. 
If the attribute `back_prop` is false, + // # no gradient computation. + // return None + + // if op_ctxt.grad_state: + // raise TypeError("Second-order gradient for while loops not supported.") + + // if isinstance(grad, ops.Tensor) : + // grad_ctxt.AddName(grad.name) + // else: + // if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): + // raise TypeError("Type %s not supported" % type(grad)) + // grad_ctxt.AddName(grad.values.name) + // grad_ctxt.AddName(grad.indices.name) + // dense_shape = grad.dense_shape + // if dense_shape is not None: + // grad_ctxt.AddName(dense_shape.name) + // grad_ctxt.Enter() + // # pylint: disable=protected-access + // result = control_flow_ops._Enter( + // grad, grad_ctxt.name, is_constant=False, + // parallel_iterations=grad_ctxt.parallel_iterations, + // name="b_exit") + // # pylint: enable=protected-access + // grad_ctxt.loop_enters.append(result) + // grad_ctxt.Exit() + // return result + } + + /// + /// A forward next_iteration is translated into a backprop identity. + /// + /// Note that the backprop next_iteration is added in switch grad. + /// + public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) + { + return (_, grad); + } + + public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) + { + return (_, grad); + } + + /// + /// Gradients for an Enter are calculated using an Exit op. + /// + /// For loop variables, grad is the gradient so just add an exit. + /// For loop invariants, we need to add an accumulator loop. + /// + public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad) + { + throw new NotImplementedException("_EnterGrad"); + // graph = ops.get_default_graph() + //# pylint: disable=protected-access + // grad_ctxt = graph._get_control_flow_context() + // # pylint: enable=protected-access + // if not grad_ctxt.back_prop: + // # Skip gradient computation, if the attribute `back_prop` is false. + // return grad + // if grad_ctxt.grad_state is None: + // # Pass the gradient through if we are not in a gradient while context. + // return grad + // if op.get_attr("is_constant"): + // # Add a gradient accumulator for each loop invariant. + // if isinstance(grad, ops.Tensor) : + // result = grad_ctxt.AddBackpropAccumulator(op, grad) + // elif isinstance(grad, ops.IndexedSlices) : + // result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) + // else: + // # TODO(yuanbyu, lukasr): Add support for SparseTensor. + // raise TypeError("Type %s not supported" % type(grad)) + // else: + // result = exit(grad) + // grad_ctxt.loop_exits.append(result) + // grad_ctxt.ExitResult([result]) + // return result + } + + public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad) + { + return _EnterGrad(op, grad); + } + + /// + /// Stop backprop for the predicate of a while loop. + /// + public object _LoopCondGrad(object _) + { + return null; + } + + } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 42cf1a17..fda9ff01 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -3,13 +3,14 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Eager; +using Tensorflow.Operations; namespace Tensorflow { public partial class Graph { // Current control flow context. 
It could be either CondContext or WhileContext - public IControlFlowContext _control_flow_context; + public ControlFlowContext _control_flow_context; // represents the nested with(...) statements public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); @@ -97,7 +98,7 @@ namespace Tensorflow /// Returns the current control flow context. /// /// A context object. - public IControlFlowContext _get_control_flow_context() + public ControlFlowContext _get_control_flow_context() { return _control_flow_context; } @@ -106,7 +107,7 @@ namespace Tensorflow /// Sets the current control flow context. /// /// a context object. - public void _set_control_flow_context(IControlFlowContext ctx) + public void _set_control_flow_context(ControlFlowContext ctx) { _control_flow_context = ctx; } diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs index 36832b35..047624d5 100644 --- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Eager; +using Tensorflow.Operations; namespace Tensorflow { @@ -15,7 +16,7 @@ namespace Tensorflow private List _seen_nodes; private List<_ControlDependenciesController> _old_stack; private bool _new_stack; - private IControlFlowContext _old_control_flow_context; + private ControlFlowContext _old_control_flow_context; public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 47908e05..254df0cf 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Operations.ControlFlows; namespace Tensorflow.Operations { @@ -107,8 +108,8 @@ namespace Tensorflow.Operations with(ops.control_dependencies(null), ctrl => { - var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); - result = new[] { r0, r1 }[_branch]; + var results = control_flow_ops._SwitchRefOrTensor(result, _pred); + result = results[_branch]; if (_outer_context != null) _outer_context.AddInnerOp(result.op); }); @@ -118,7 +119,7 @@ namespace Tensorflow.Operations // Mark Switch output as seen by this context and any outer contexts, // just like what we do for normal op outputs in _AddOpInternal() below. 
- IControlFlowContext ctxt = this; + ControlFlowContext ctxt = this; while (ctxt != null) { ctxt.values.Add(result.name); @@ -223,8 +224,8 @@ namespace Tensorflow.Operations _values.Add(real_val.name); _external_values[real_val.name] = real_val; } - var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); - real_val = new[] {t0, t1}[_branch]; + var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred); + real_val = results[_branch]; _external_values[val.name] = real_val; } else @@ -238,8 +239,8 @@ namespace Tensorflow.Operations return real_val; } - protected override void _AddOpInternal(Operation op) - { + protected override void _AddOpInternal(Operation op) + { if (op.inputs.Length == 0) { //If we're in a while loop, remove any control inputs from outside the @@ -282,11 +283,11 @@ namespace Tensorflow.Operations // TODO: implement below code dependencies //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") // op._add_control_input(_pivot.op); - } - - // Mark op's outputs as seen by this context and any outer contexts. + } + + // Mark op's outputs as seen by this context and any outer contexts. var output_names = op.outputs.Select(x => x.name).ToArray(); - IControlFlowContext ctxt = this; + ControlFlowContext ctxt = this; while (ctxt != null) { foreach (var name in output_names) @@ -298,9 +299,31 @@ namespace Tensorflow.Operations op.graph.prevent_fetching(op); if (_outer_context != null) - _outer_context.AddInnerOp(op); - } - + _outer_context.AddInnerOp(op); + } + + public override GradLoopState grad_state + { + get + { + var whc = GetWhileContext(); + if (whc != null) + return whc.grad_state; + return null; + } + } + + public override bool back_prop + { + get + { + var whc = GetWhileContext(); + if (whc != null) + return whc.back_prop; + return false; + } + } + public CondContextDef to_proto(string export_scope) { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 56b38846..48a519db 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Operations.ControlFlows; namespace Tensorflow.Operations { @@ -22,21 +23,25 @@ namespace Tensorflow.Operations /// 4. A ControlFlowContext has _context_stack. 
/// Pushed and popped by ctxt.Enter() and ctxt.Exit() /// - public abstract class ControlFlowContext : Python, IPython, IControlFlowContext + public abstract class ControlFlowContext : Python, IPython { /// /// The predicate tensor in this branch /// protected Tensor _pivot; + public Tensor pivot + { + get => _pivot; + } - protected Stack _context_stack; - protected IControlFlowContext _outer_context; + protected Stack _context_stack; + protected ControlFlowContext _outer_context; protected Dictionary _external_values; public ControlFlowContext() { - _context_stack = new Stack(); + _context_stack = new Stack(); } public string name { get => _name; } @@ -111,8 +116,13 @@ namespace Tensorflow.Operations _AddOpInternal(op); } - public IControlFlowContext outer_context { get { return _outer_context; } } + public ControlFlowContext outer_context { get { return _outer_context; } } public HashSet values => _values; + + public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method"); + + public virtual bool back_prop => throw new NotImplementedException("abstract method"); + public virtual Tensor AddValue(Tensor val) { // to be overridden @@ -147,7 +157,7 @@ namespace Tensorflow.Operations /// /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. /// - public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) + public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) { while (ctxt != maybe_containing_ctxt) { @@ -164,6 +174,16 @@ namespace Tensorflow.Operations var internal_control_inputs = op.control_inputs; } + /// + /// Return the while context containing this context + /// + public virtual WhileContext GetWhileContext() + { + if (_outer_context != null) + return _outer_context.GetWhileContext(); + return null; + } + public object to_proto() { throw new NotImplementedException(); @@ -173,5 +193,6 @@ namespace Tensorflow.Operations public void Dispose() { } + } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs new file mode 100644 index 00000000..c87ba1c6 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs @@ -0,0 +1,277 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.ControlFlows +{ + /// + /// Maintain the mapping from the loops to their grad states. + /// + public class ControlFlowState + { + //class ControlFlowState(object): + // """Maintain the mapping from the loops to their grad states.""" + + // def __init__(self): + // self._map = {} # maps forward loop context to GradLoopState + + // def GetGradState(self, op, before): + // """Return the grad state for this op if it's in a forward loop context.""" + // if before and util.IsLoopExit(op): + // forward_ctxt = op._get_control_flow_context() + // forward_ctxt = forward_ctxt.outer_context + // if forward_ctxt: + // forward_ctxt = forward_ctxt.GetWhileContext() + // else: + // forward_ctxt = _GetWhileContext(op) + // if forward_ctxt: + // return self._map.get(forward_ctxt) + // return None + + // def ProcessUnusedLoopExits(self, pending_count, to_ops_set): + // """Process all the "unused" loop exits. + + // The "unused" exits of the loops are added to `unused_exits`. An exit is + // unused if its pending_count is 0. 
If there is an exit with real gradient, + // all these deferred exits will enter the backprop loop with zero gradient. + // Otherwise, they will enter the backprop loop with None. As an example, + // people often write: + + // ```python + // v1, _ = tf.while_loop(p, b, [x1, x2]) + // result = gradients(v1, x1) + // ``` + + // The exit node for x2 is not included by the betweenness analysis. But we + // need to backprop x2 if x2 is involved in computing v1. + + // Args: + // pending_count: The number of backprop inputs for every op. + // to_ops_set: The set of ops for ys in gradients(ys, xs) + + // Returns: + // The set of unused loop exits that we know at this point we need + // to backprop. + // """ + // loop_exits = [] + // for grad_state in self._map.values(): + // for y in grad_state.forward_loop_exits: + // if pending_count[y.op] == 0: + // grad_state.pending_exits_count -= 1 + // if y.op not in to_ops_set: + // grad_state.unused_exits.append(y) + // if grad_state.pending_exits_count == 0: + // loop_exits.extend(grad_state.unused_exits) + // # Need to include Enters in backprop for higher-order gradients. + // for y in grad_state.forward_context.loop_enters: + // if pending_count[y.op] == 0: + // pending_count[y.op] = 1 + // return loop_exits + + // def EnterGradWhileContext(self, op, before): + // """Enter the WhileContext for gradient computation.""" + // grad_state = self.GetGradState(op, before) + // if grad_state: + // grad_state.grad_context.Enter() + + // def ExitGradWhileContext(self, op, before): + // """Exit the WhileContext for gradient computation.""" + // grad_state = self.GetGradState(op, before) + // if grad_state: + // grad_state.grad_context.Exit() + + // def AddWhileContext(self, op, between_op_list, between_ops): + // """Add the grad state for the while loop that op belongs to. + + // Note that op is an Exit, and this method must be called in + // the control flow context where gradients() is called. + + // Note that this method modifies `between_op_list` and `between_ops`. + // """ + // forward_ctxt = _GetWhileContext(op) + // grad_state = self._map.get(forward_ctxt) + // if grad_state is None: + // # This is a new while loop so create a grad state for it. + // outer_forward_ctxt = forward_ctxt.outer_context + // if outer_forward_ctxt: + // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() + // outer_grad_state = None + // if outer_forward_ctxt: + // outer_grad_state = self._map.get(outer_forward_ctxt) + // grad_state = GradLoopState(forward_ctxt, outer_grad_state) + // self._map[forward_ctxt] = grad_state + + // # We need to include all exits of a loop for backprop. + // for loop_exit in grad_state.forward_loop_exits: + // if loop_exit.op not in between_ops: + // between_ops.add(loop_exit.op) + // between_op_list.append(loop_exit.op) + + // def ZerosLikeForExit(self, val): + // """Create zeros_like gradient for a loop exit. + + // If the result of a loop variable is not used but is involved in + // computing the result of some needed loop variable, we create a + // zero-valued tensor that is fed as gradient for the Exit node of that + // loop variable. Note that val.op is an Exit, and this method must be + // called in the control flow context where gradients() is called. + + // Args: + // val: The output tensor of an Exit op. + + // Returns: + // A zero tensor of the same shape of val. 
+ // """ + // val_shape = val.get_shape() + // forward_ctxt = val.op._get_control_flow_context() + // outer_forward_ctxt = forward_ctxt.outer_context + // if outer_forward_ctxt: + // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() + // outer_grad_state = None + // if outer_forward_ctxt: + // outer_grad_state = self._map.get(outer_forward_ctxt) + // if outer_grad_state: + // # This is a nested loop. + // if val_shape.is_fully_defined(): + // # If the shape is known statically, just create a zero tensor + // # with the right shape in the right context. + // outer_grad_state.grad_context.Enter() + // result = array_ops.zeros(val_shape.dims, val.dtype) + // outer_grad_state.grad_context.Exit() + // else: + // # Only the shape of value is needed for backprop. + // forward_ctxt.outer_context.Enter() + // shape = array_ops.shape_internal(val, optimize=False) + // forward_ctxt.outer_context.Exit() + // # Save the shape to a stack. + // history_shape = outer_grad_state.AddForwardAccumulator(shape) + // # Get the shape back from the stack. + // outer_grad_ctxt = outer_grad_state.grad_context + // outer_grad_ctxt.Enter() + // real_shape = outer_grad_state.AddBackpropAccumulatedValue( + // history_shape, shape) + // result = array_ops.zeros(real_shape, val.dtype) + // outer_grad_ctxt.Exit() + // else: + // # This is not a nested loop. + // if val_shape.is_fully_defined(): + // # If the shape is known statically, just create a zero tensor + // # with the right shape. + // result = array_ops.zeros(val_shape.dims, val.dtype) + // else: + // result = array_ops.zeros_like(val, optimize=False) + // return result + + // def ZerosLike(self, op, index): + // """Create zeros_like for the specified output of an op. + + // If op is in a while loop that is part of gradients(), this method + // must be called in its grad loop context. + + // Args: + // op: A tensorflow operation. + // index: the index for a specific output of the op. + + // Returns: + // A zero tensor of the same shape of op.outputs[index]. + // """ + // if util.IsLoopSwitch(op): + // return None + // if op.graph._building_function: # pylint: disable=protected-access + // # The optimization here is tricky to apply to functions + // return array_ops.zeros_like(op.outputs[index]) + // dead_branch = util.IsSwitch(op) + // forward_ctxt = _GetWhileContext(op) + // grad_state = self._map.get(forward_ctxt) + // if grad_state is None: + // # op is not in a while loop that is part of gradients(). + // return ZerosLikeOutsideLoop(op, index) + // op_ctxt = op._get_control_flow_context() + // val = ops.convert_to_tensor(op.outputs[index], name="tensor") + // shape = val.get_shape() + // if shape.is_fully_defined(): + // # If the shape is known statically, just create a zero tensor with + // # the right shape in the grad loop context. + // result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) + // if dead_branch: + // # op is a cond switch. Guard the zero tensor with a switch. + // pred = grad_state.history_map.get(op_ctxt.pred.name) + // branch = op_ctxt.branch + // result = _SwitchRefOrTensor(result, pred)[1 - branch] + // else: + // # Unknown shape so keep a history of the shape at runtime. + // if dead_branch: + // # Need to add a special switch to guard the value. 
+ // pred = op_ctxt.pred + // branch = op_ctxt.branch + // op_ctxt.outer_context.Enter() + // val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] + // zeros_shape = array_ops.shape_internal(val, optimize=False) + // op_ctxt.outer_context.Exit() + // val.op._set_control_flow_context(op_ctxt) + // zeros_shape.op._set_control_flow_context(op_ctxt) + // else: + // op_ctxt.Enter() + // zeros_shape = array_ops.shape_internal(val, optimize=False) + // op_ctxt.Exit() + + // # Add forward accumulator for shape. + // grad_state.grad_context.Exit() + // history_zeros_shape = grad_state.AddForwardAccumulator( + // zeros_shape, dead_branch=dead_branch) + // grad_state.grad_context.Enter() + + // # Create a zero tensor with the right shape. + // shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, + // zeros_shape, dead_branch) + // result = array_ops.zeros(shape, val.dtype) + // return result + + // def PostProcessing(self): + // """Perform postprocessing at the end of gradients(). + + // We have created the gradient graph at this point. So this function + // can be used to perform any postprocessing on the gradient graph. + // We currently perform the following postprocessing: + // 1. Patch the gradient graph if the output of a loop variable + // doesn't depend on its input. + // """ + // for _, grad_state in self._map.items(): + // for _, b_merge in grad_state.switch_map.items(): + // if b_merge.op.inputs[0] == b_merge.op.inputs[1]: + // # The value of this loop variable at iteration i+1 doesn't + // # depend on its value at iteration i. So use zeros as the + // # gradients for all iterations > 0. + // dtype = b_merge.op.inputs[0].dtype + // shape = b_merge.op.inputs[0].get_shape() + // # pylint: disable=protected-access + // if shape.is_fully_defined(): + // grad_state.grad_context.Enter() + // # Create a zeros and use it for iterations > 0. + // grad_val = constant_op.constant(0, dtype=dtype, shape=shape) + // next_grad_val = _NextIteration(grad_val) + // grad_state.grad_context.Exit() + // else: + // # Create a zeros in the outer grad context. + // outer_grad_ctxt = grad_state.grad_context.outer_context + // if outer_grad_ctxt: + // outer_grad_ctxt.Enter() + // enter_grad_op = b_merge.op.inputs[0].op + // enter_grad = enter_grad_op.inputs[0] + // grad_shape = array_ops.shape_internal(enter_grad, optimize=False) + // grad_val = array_ops.zeros(grad_shape) + // if outer_grad_ctxt: + // outer_grad_ctxt.Exit() + // # Use the zeros for iterations > 0. + // grad_state.grad_context.Enter() + // next_grad_val = _NextIteration(grad_val) + // grad_state.grad_context.Exit() + // b_merge.op._update_input(1, next_grad_val) + // # pylint: enable=protected-access + + } + + + + +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs new file mode 100644 index 00000000..e8fda1a0 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs @@ -0,0 +1,398 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.ControlFlows +{ + public class GradLoopState + { + + //class GradLoopState(object): + // """The state used for constructing the gradient graph for a while loop. + + // We create a GradLoopState for each while loop in forward and its + // corresponding while loop in backprop. This gives us access to both + // the forward and the backprop WhileContexts. 
+ + // During the construction of gradient graph, any time when we detect + // a forward value that is needed for backprop, we create a history + // accumulator and add it to `history_map`. Any time when we backprop + // a loop switch op (in _SwitchGrad), we add the grad merge op in + // `switch_map`. + // """ + + // def __init__(self, forward_ctxt, outer_grad_state): + // # The grad loop state for the outer while loop. + // self._outer_grad_state = None + + // # The while loop context for forward. + // self._forward_context = None + + // # The loop counter added by AddForwardLoopCounter. It is the value + // # of the loop counter for the next iteration. + // self._forward_index = None + + // # A sync op for forward. + // self._forward_sync = None + + // # The while loop context for backprop. + private WhileContext _grad_context = null; + + public WhileContext grad_context => _grad_context; + + // # The loop counter added by AddBackpropLoopCounter. It is the value + // # of the loop counter for the current iteration. + // self._grad_index = None + + // # A sync op for backprop. + // self._grad_sync = None + + // # Information needed by backprop. + private Hashtable _history_map = new Hashtable(); + public Hashtable history_map => _history_map; + private Hashtable _switch_map = new Hashtable(); + public Hashtable switch_map => _switch_map; + // self._unused_exits = [] + // self._deferred_exits = [] + // self._forward_loop_exits = list(forward_ctxt.loop_exits) + // self._pending_exits_count = len(forward_ctxt.loop_exits) + + // self._outer_grad_state = outer_grad_state + // if outer_grad_state: + // outer_forward_ctxt = outer_grad_state.forward_context + // else: + // if not hasattr(forward_ctxt, "outer_context"): + // raise ValueError("Failed to call gradients on a while loop without" + // "properly serializing graph via MetaGraphDef") + // outer_forward_ctxt = forward_ctxt.outer_context + + // # Add the forward loop counter. + // with forward_ctxt._graph.as_default(): # pylint: disable=protected-access + // if outer_forward_ctxt: + // outer_forward_ctxt.Enter() + // cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) + // if outer_forward_ctxt: + // outer_forward_ctxt.Exit() + // self._forward_context = forward_ctxt + // self._forward_index = forward_index + + // # Add the backprop WhileContext, and the backprop loop counter. + // if outer_grad_state: + // # This is a nested loop. Remember the iteration counts for each + // # execution of this inner loop. 
+ // outer_forward_ctxt.AddName(cnt.name) + // history_cnt = outer_grad_state.AddForwardAccumulator(cnt) + + // outer_grad_ctxt = outer_grad_state.grad_context + // outer_grad_ctxt.Enter() + // self._grad_context = WhileContext( + // maximum_iterations=forward_ctxt.maximum_iterations, + // parallel_iterations=forward_ctxt.parallel_iterations, + // back_prop=forward_ctxt.back_prop, + // swap_memory=forward_ctxt.swap_memory, + // name=forward_ctxt.name, + // grad_state=self) + // real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) + // self._grad_index = self._grad_context.AddBackpropLoopCounter( + // real_cnt, outer_grad_state) + // outer_grad_ctxt.Exit() + // else: + // if outer_forward_ctxt: + // outer_forward_ctxt.Enter() + // self._grad_context = WhileContext( + // maximum_iterations=forward_ctxt.maximum_iterations, + // parallel_iterations=forward_ctxt.parallel_iterations, + // back_prop=forward_ctxt.back_prop, + // swap_memory=forward_ctxt.swap_memory, + // name=forward_ctxt.name, + // grad_state=self) + // self._grad_index = self._grad_context.AddBackpropLoopCounter( + // cnt, outer_grad_state) + // if outer_forward_ctxt: + // outer_forward_ctxt.Exit() + + // @property + // def outer_grad_state(self): + // """The grad loop state for outer loop.""" + // return self._outer_grad_state + + // @property + // def forward_context(self): + // """The while loop context for forward.""" + // return self._forward_context + + // @property + // def forward_index(self): + // """The loop index of forward loop.""" + // return self._forward_index + + // @property + // def forward_sync(self): + // """A control trigger node for synchronization in the forward loop. + + // One main use is to keep the push ops of a stack executed in the + // iteration order. + // """ + // if self._forward_sync is None: + // with ops.control_dependencies(None): + // self._forward_sync = control_trigger(name="f_sync") + // self._forward_sync._set_control_flow_context(self._forward_context) + // self._forward_index.op._add_control_input(self._forward_sync) + // return self._forward_sync + + // @property + // def grad_context(self): + // """The corresponding WhileContext for gradient.""" + // return self._grad_context + + // @property + // def grad_index(self): + // """The loop index of backprop loop.""" + // return self._grad_index + + // @property + // def grad_sync(self): + // """A control trigger node for synchronization in the grad loop. + + // One main use is to keep the pop ops of a stack executed in the + // iteration order. 
+ // """ + // if self._grad_sync is None: + // with ops.control_dependencies(None): + // self._grad_sync = control_trigger(name="b_sync") + // self._grad_sync._set_control_flow_context(self._grad_context) + // self._grad_index.op._add_control_input(self._grad_sync) + // if self._grad_context.outer_context: + // self._grad_context.outer_context.AddInnerOp(self._grad_sync) + // return self._grad_sync + + // @property + // def history_map(self): + // """The map that records all the tensors needed for backprop.""" + // return self._history_map + + // @property + // def switch_map(self): + // """The map that records all the Switch ops for the while loop.""" + // return self._switch_map + + // @property + // def unused_exits(self): + // """The list of "unused" exits.""" + // return self._unused_exits + + // @property + // def deferred_exits(self): + // """The list of "deferred" exits.""" + // return self._deferred_exits + + // @property + // def forward_loop_exits(self): + // """The list of exits of the forward loop.""" + // return self._forward_loop_exits + + // @property + // def pending_exits_count(self): + // """The number of exits we expect to see but haven't.""" + // return self._pending_exits_count + + // @pending_exits_count.setter + // def pending_exits_count(self, cnt): + // """Set the pending count to cnt.""" + // self._pending_exits_count = cnt + + /// + /// Add an accumulator for each forward tensor that is needed in backprop. + /// + /// This is added to the forward loop at the first time when a tensor + /// in the forward loop is used by backprop gradient computation loop. + /// We create an accumulator that accumulates the value of tensor at each + /// iteration. Called in the control flow context where gradients() is called. + /// + /// The pseudocode is: + /// ``` + /// acc = stack(); + /// while (_pivot) { + /// acc = stack_push(acc, value); + /// } + /// ``` + /// + /// We make sure that the stack push op in one iteration is executed before + /// next iteration. This is achieved by adding a control edge from + /// `forward_index.op.inputs[0].op` to the push op, and another control + /// edge from the push op to either `forward_index.op` or `forward_sync`. + /// + /// The source tensor in forward that is to be accumulated. + /// True iff the tensor is on a dead branch of a cond. + /// The stack that contains the accumulated history of the tensor. + public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) + { + throw new NotImplementedException("AddForwardAccumulator"); + // # curr_ctxt is the context that tf.gradients was called in. + // with self._forward_index.graph.as_default(): + // curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + // with ops.control_dependencies(None): + // if curr_ctxt: + // curr_ctxt.Enter() + // with ops.colocate_with(value): + // # We only need to pass maximum_iterations to the stack if + // # we're inside an XLA context. + // if not util.IsInXLAContext(value.op): + // max_size = constant_op.constant(-1, dtypes.int32) + // else: + // max_size = GetMaxSizeFromNestedMaximumIterations( + // value, self.forward_context) + // acc = gen_data_flow_ops.stack_v2( + // max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") + // if curr_ctxt: + // curr_ctxt.Exit() + + // # Make acc available in the forward context. + // enter_acc = self.forward_context.AddValue(acc) + + // # Add the stack_push op in the context of value.op. 
+ // swap_enabled = self.forward_context.swap_memory + // value_ctxt = util.GetOutputContext(value.op) + // if value_ctxt == self.forward_context: + // # value is not nested in the forward context. + // self.forward_context.Enter() + // push = gen_data_flow_ops.stack_push_v2( + // enter_acc, value, swap_memory=swap_enabled) + // self.forward_context.Exit() + // # Protect stack push and order it before forward_index. + // self.forward_index.op._add_control_input(push.op) + // else: + // # value is in a cond context within the forward context. + // if not isinstance(value_ctxt, CondContext): + // raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) + // if dead_branch: + // # The special case for creating a zero tensor for a dead + // # branch of a switch. See ControlFlowState.ZerosLike(). + // value_ctxt.outer_context.Enter() + // push = gen_data_flow_ops.stack_push_v2( + // enter_acc, value, swap_memory=swap_enabled) + // value_ctxt.outer_context.Exit() + // push.op._set_control_flow_context(value_ctxt) + // else: + // value_ctxt.Enter() + // push = gen_data_flow_ops.stack_push_v2( + // enter_acc, value, swap_memory=swap_enabled) + // value_ctxt.Exit() + // # Protect stack push and order it before forward_sync. + // self.forward_sync._add_control_input(push.op) + // # Order stack push after the successor of forward_index + // add_op = self.forward_index.op.inputs[0].op + // push.op._add_control_input(add_op) + // return acc + } + + // """Add the getter for an accumulated value in the grad context. + // + // This is added to the backprop loop. Called in the grad context to + // get the value of an accumulated value. The stack pop op must be guarded + // by the pred of the controlling cond. + // + // Args: + // history_value: The history (a stack) of a value. + // value: The value that is pushed onto the stack. + // dead_branch: True iff the tensor is on a dead branch of a cond. + // + // Returns: + // The current value (the top of the stack). + // """ + public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false) + { + throw new NotImplementedException(); + // history_ctxt = history_value.op._get_control_flow_context() + // # Find the cond context that controls history_value if any. + // cond_ctxt = None + // value_ctxt = value.op._get_control_flow_context() + // while value_ctxt and value_ctxt != history_ctxt: + // if isinstance(value_ctxt, CondContext): + // cond_ctxt = value_ctxt + // break + // value_ctxt = value_ctxt.outer_context + // with ops.control_dependencies(None): + // self.grad_context.Enter() + // if cond_ctxt: + // # Guard stack pop with a switch if it is controlled by a cond. + // grad_state = self + // pred = None + // while pred is None and grad_state: + // pred = grad_state.history_map.get(cond_ctxt.pred.name) + // grad_state = grad_state.outer_grad_state + // if pred is None: + // pred = cond_ctxt.pred + // branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch + // history_value = _SwitchRefOrTensor(history_value, pred)[branch] + // pop = gen_data_flow_ops.stack_pop_v2(history_value, + // value.dtype.base_dtype) + // pop.set_shape(value.get_shape()) + // self.grad_context.Exit() + // parallel_iterations = self.grad_context.parallel_iterations + // if parallel_iterations > 1: + // # All pops are ordered after pivot_for_body and before grad_sync. + // self.grad_sync._add_control_input(pop.op) + // return pop + } + + // def GetRealValue(self, value): + // """Get the real value of `value`. 
+ + // If backprop "uses" a value produced by forward inference, an accumulator + // is added in the forward loop to accumulate its values. We use the + // accumulated value. This method must be called in the grad loop context. + // `value` must be in forward and needed for backprop. + + // Args: + // value: A tensor to be captured. + + // Returns: + // The same tensor obtained from the saved history. + // """ + // assert value.op.type not in ["Variable", "VariableV2"] + // real_value = self._history_map.get(value.name) + // if real_value is None: + // cur_value = value + // cur_grad_state = self + // while True: + // enter_op = util.GetLoopConstantEnter(cur_value) + // if enter_op: + // # Special case: cur_value comes from a constant Enter node. + // cur_value = enter_op.inputs[0] + // cur_grad_state = cur_grad_state.outer_grad_state + // if cur_grad_state is None: + // # We are now outside all nested loops for this gradient(), + // # so `value` is a loop invariant and there is no need to + // # save the history of value. Just make cur_value to enter + // # the right control flow context. + // real_value = self._grad_context.AddValue(cur_value) + // break + // elif constant_op.is_constant(cur_value): + // # If the value to be forwarded is a constant, clone the constant in + // # the gradient loop rather than using a stack. + // # TODO(phawkins): consider hoisting the constant out of the loop + // # instead. + // real_value = constant_op.constant( + // tensor_util.constant_value(cur_value), dtype=cur_value.dtype) + // break + // else: + // # Record the history of this value in forward_ctxt. + // self._grad_context.Exit() + // history_value = cur_grad_state.AddForwardAccumulator(cur_value) + // self._grad_context.Enter() + // break + + // if real_value is None: + // # Add the stack pop op in the grad context. + // real_value = cur_grad_state.AddBackpropAccumulatedValue( + // history_value, cur_value) + // if cur_grad_state != self: + // real_value = self._grad_context.AddValue(real_value) + // self._history_map[value.name] = real_value + // return real_value + + + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs index 7fdd22f5..f9dde8c4 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs @@ -4,13 +4,15 @@ using System.Text; namespace Tensorflow { - public interface IControlFlowContext - { - void AddOp(Operation op); - IControlFlowContext outer_context { get; } - HashSet values { get; } - Tensor AddValue(Tensor val); - void AddInnerOp(Operation resultOp); - object to_proto(); - } + // henon: this was too much trouble. there is no value just cost to use an interface here. 
+ //public interface IControlFlowContext + //{ + // void AddOp(Operation op); + // IControlFlowContext outer_context { get; } + // HashSet values { get; } + // Tensor pivot { get; } + // Tensor AddValue(Tensor val); + // void AddInnerOp(Operation resultOp); + // object to_proto(); + //} } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index d800679b..966ac83f 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -1,11 +1,26 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations.ControlFlows; namespace Tensorflow.Operations { public class WhileContext : ControlFlowContext { + private bool _back_prop=true; + + private GradLoopState _grad_state =null; + + public override WhileContext GetWhileContext() + { + return this; + } + + + public override GradLoopState grad_state => _grad_state; + + public override bool back_prop => _back_prop; + public static WhileContext from_proto(object proto) { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 262d8e75..9b3aefe2 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -7,7 +7,7 @@ namespace Tensorflow { public partial class Operation { - private IControlFlowContext _control_flow_context; + private ControlFlowContext _control_flow_context; /// /// Add this op to its control flow context. @@ -39,12 +39,12 @@ namespace Tensorflow _add_control_input(op); } - public void _set_control_flow_context(IControlFlowContext ctx) + public void _set_control_flow_context(ControlFlowContext ctx) { _control_flow_context = ctx; } - public IControlFlowContext _get_control_flow_context() + public ControlFlowContext _get_control_flow_context() { return _control_flow_context; } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 11950b46..08b8c8b5 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Operations; +using Tensorflow.Operations.ControlFlows; using util = Tensorflow.control_flow_util; namespace Tensorflow @@ -93,9 +94,9 @@ namespace Tensorflow /// /// /// - public static object MaybeCreateControlFlowState(List between_op_list, List between_ops, bool colocate_gradients_with_ops) + public static ControlFlowState MaybeCreateControlFlowState(List between_op_list, List between_ops, bool colocate_gradients_with_ops) { - object loop_state = null; + ControlFlowState loop_state = null; foreach (var op in between_op_list) { @@ -103,7 +104,7 @@ namespace Tensorflow { if(loop_state == null) { - // loop_state = ControlFlowState(); + loop_state = new ControlFlowState(); } } } @@ -207,7 +208,7 @@ namespace Tensorflow /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to /// `output_true`, otherwise it goes to `output_false`. 
/// - public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") + public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") { data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below @@ -298,7 +299,9 @@ namespace Tensorflow */ // Add the Switch to the graph. - var (p_2, p_1) = @switch(pred, pred); + var switch_result= @switch(pred, pred); + var p_2=switch_result[0]; + var p_1 = switch_result[1]; var pivot_1 = array_ops.identity(p_1, name: "switch_t"); var pivot_2 = array_ops.identity(p_2, name: "switch_f"); pred = array_ops.identity(pred, name: "pred_id"); @@ -379,7 +382,9 @@ namespace Tensorflow return with(ops.name_scope(name, "cond", new { pred }), delegate { // Add the Switch to the graph. - var (p_2, p_1) = @switch(pred, pred); + var switch_result = @switch(pred, pred); + var p_2 = switch_result[0]; + var p_1 = switch_result[1]; var pivot_1 = array_ops.identity(p_1, name: "switch_t"); var pivot_2 = array_ops.identity(p_2, name: "switch_f"); pred = array_ops.identity(pred, name: "pred_id"); @@ -460,7 +465,7 @@ namespace Tensorflow /// /// /// - public static (Tensor, Tensor) @switch(Tensor data, + public static Tensor[] @switch(Tensor data, Tensor pred, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index 98ccbb06..5e2fc43e 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -30,7 +30,7 @@ namespace Tensorflow /// /// Return the control flow context for the output of an op. /// - public static IControlFlowContext GetOutputContext(Operation op) + public static ControlFlowContext GetOutputContext(Operation op) { var ctxt = op._get_control_flow_context(); // Exit nodes usually have a control flow context, except in the case where the diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index 31e2cad3..78e70053 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -33,14 +33,14 @@ namespace Tensorflow /// output_false: A `Tensor`. Has the same type as `data`. /// output_true: A `Tensor`. Has the same type as `data`. /// - public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) + public static Tensor[] @switch(Tensor data, Tensor pred, string name = null) { var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); var _inputs_flat = _op.inputs; var _attrs = ("T", _op.get_attr("T")); // TODO: missing original code //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); - return (_op.outputs[0], _op.outputs[1]); + return new []{_op.outputs[0], _op.outputs[1]}; } public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
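
Since `@switch`, `_SwitchRefOrTensor` and `gen_control_flow_ops.@switch` now return a `Tensor[]` rather than a `(Tensor, Tensor)` tuple, callers index the result by branch, as `CondContext.AddValue` does above with `results[_branch]`. A minimal usage sketch (with `data`, `pred` and `branch` standing in for values available at the call site):

        var results = control_flow_ops._SwitchRefOrTensor(data, pred);
        var output_false = results[0]; // taken when pred is false
        var output_true  = results[1]; // taken when pred is true

        // Selecting by a 0/1 branch index, as CondContext does internally:
        var branch_output = results[branch];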
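
The ported gradient functions `_SwitchGrad`, `_ExitGrad` and `_EnterGrad` still throw `NotImplementedException` and keep the reference Python as comments. As a rough, non-authoritative sketch of the direction, the CondContext arm of `_SwitchGrad` could translate as below; a public `branch` property on `CondContext` (this patch only shows the private `_branch` field) and direct use of `gen_control_flow_ops.merge` in place of the Python `merge` wrapper are assumptions made for illustration.

        // Hypothetical helper covering only the CondContext case of _SwitchGrad,
        // transliterated from the commented Python above. Not part of this patch.
        private static Tensor[] _SwitchGradCondCase(Tensor[] grad, CondContext op_ctxt)
        {
            // Gradient coming from the branch that was not taken.
            var zero_grad = grad[1 - op_ctxt.branch]; // assumes CondContext exposes `branch`
            if (zero_grad == null)
            {
                // Non-trainable dtypes can still yield null here; the Python code also
                // special-cases resource variables, which this sketch omits.
                return new Tensor[] { null, null };
            }
            // merge returns (output, value_index); only the merged output is the gradient.
            var merged = gen_control_flow_ops.merge(grad, name: "cond_grad").Item1;
            return new Tensor[] { merged, null };
        }

The WhileContext arm additionally needs the `switch_map` bookkeeping that `GradLoopState` introduces in this patch.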
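
`ControlFlowState` is likewise added with its body entirely commented out. To illustrate the forward-while-context-to-grad-state mapping it is meant to maintain, `GetGradState` might look roughly like the following sketch; the `_map` dictionary and an `IsLoopExit` helper on `control_flow_util` are assumed here, while `_get_control_flow_context`, `outer_context` and `GetWhileContext` already exist in this patch.

        // Hypothetical sketch of ControlFlowState.GetGradState; not part of this patch.
        private readonly Dictionary<WhileContext, GradLoopState> _map
            = new Dictionary<WhileContext, GradLoopState>();

        public GradLoopState GetGradState(Operation op, bool before)
        {
            ControlFlowContext forward_ctxt;
            if (before && control_flow_util.IsLoopExit(op)) // assumed helper, mirrors util.IsLoopExit
                forward_ctxt = op._get_control_flow_context()?.outer_context;
            else
                forward_ctxt = op._get_control_flow_context();

            // Resolve to the enclosing while context, if any, and look up its grad state.
            var while_ctxt = forward_ctxt?.GetWhileContext();
            if (while_ctxt != null && _map.ContainsKey(while_ctxt))
                return _map[while_ctxt];
            return null;
        }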