@@ -1,12 +1,79 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
{ | { | ||||
/// <summary> | |||||
/// Gradients for operators defined in control_flow_ops.py.cs | |||||
/// </summary> | |||||
public class control_flow_grad | public class control_flow_grad | ||||
{ | { | ||||
/// <summary> | |||||
/// Gradients for a Switch op are calculated using a Merge op.
/// | |||||
/// If the switch is a loop switch, it will be visited twice. We create | |||||
/// the merge on the first visit, and update the other input of the merge | |||||
/// on the second visit. A next_iteration is also added on the second visit.
/// </summary> | |||||
/// <returns></returns> | |||||
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) | |||||
{ | |||||
throw new NotImplementedException("_SwitchGrad"); | |||||
//graph = ops.get_default_graph() | |||||
//# pylint: disable=protected-access | |||||
//op_ctxt = op._get_control_flow_context() | |||||
//grad_ctxt = graph._get_control_flow_context() | |||||
//# pylint: enable=protected-access | |||||
//if isinstance(op_ctxt, WhileContext): | |||||
// merge_grad = grad_ctxt.grad_state.switch_map.get(op) | |||||
// if merge_grad is not None: | |||||
// # This is the second time this Switch is visited. It comes from | |||||
// # the non-exit branch of the Switch, so update the second input | |||||
// # to the Merge. | |||||
// # TODO(yuanbyu): Perform shape inference with this new input. | |||||
// if grad[1] is not None: | |||||
// # pylint: disable=protected-access | |||||
// control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], | |||||
// enforce_shape_invariant=False) | |||||
// # pylint: enable=protected-access | |||||
// return None, None | |||||
// elif grad[0] is not None: | |||||
// # This is the first time this Switch is visited. It comes from | |||||
// # the Exit branch, which is grad[0]. grad[1] is empty at this point. | |||||
// # Use grad[0] for both inputs to merge for now, but update the second | |||||
// # input of merge when we see this Switch the second time. | |||||
// merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] | |||||
// grad_ctxt.grad_state.switch_map[op] = merge_grad | |||||
// return merge_grad, None | |||||
// else: | |||||
// # This is the first time this Switch is visited. It comes from the | |||||
// # Identity branch. Such a Switch has `None` gradient for the Exit branch, | |||||
// # meaning the output is not differentiable. | |||||
// return None, None | |||||
//elif isinstance(op_ctxt, CondContext): | |||||
// zero_grad = grad[1 - op_ctxt.branch] | |||||
// # At this point, we have created zero_grad guarded by the right switch. | |||||
// # Unfortunately, we may still get None here for not trainable data types. | |||||
// if zero_grad is None: | |||||
// # For resource variables we get None always on the other branch, so bypass | |||||
// # this. | |||||
// if op.inputs[0].dtype == dtypes.resource: | |||||
// return merge( | |||||
// [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None | |||||
// return None, None | |||||
// return merge(grad, name="cond_grad")[0], None | |||||
//else: | |||||
// false_grad = switch(grad[0], op.inputs[1])[0] | |||||
// true_grad = switch(grad[1], op.inputs[1])[1] | |||||
// return merge([false_grad, true_grad])[0], None | |||||
} | |||||
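// Illustrative sketch (not part of this change): the final branch of the commented
// Python above, where the Switch sits in neither a WhileContext nor a CondContext.
// It assumes the array-returning @switch from this port and the tuple-returning
// merge shown at the end of this diff (gen_control_flow_ops); the helper name is
// made up for illustration only.
private Tensor[] _SwitchGradNoContextSketch(Operation op, Tensor[] grads)
{
    // Route each incoming gradient back through a Switch driven by the original
    // predicate (op.inputs[1]), then merge the two branches into one gradient.
    var false_grad = control_flow_ops.@switch(grads[0], op.inputs[1])[0];
    var true_grad = control_flow_ops.@switch(grads[1], op.inputs[1])[1];
    var merged = gen_control_flow_ops.merge(new[] { false_grad, true_grad }).Item1;
    // The predicate input of the Switch gets no gradient.
    return new Tensor[] { merged, null };
}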
/// <summary> | |||||
/// Gradients for a Merge op are calculated using a Switch op. | |||||
/// </summary> | |||||
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
var grad = grads[0]; | var grad = grads[0]; | ||||
@@ -14,10 +81,164 @@ namespace Tensorflow.Gradients | |||||
var input_op = op.inputs[0].op; | var input_op = op.inputs[0].op; | ||||
var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
var op_ctxt = control_flow_util.GetOutputContext(input_op); | var op_ctxt = control_flow_util.GetOutputContext(input_op); | ||||
var pred = (op_ctxt as CondContext).pred; | |||||
var grad_ctxt = graph._get_control_flow_context(); | |||||
switch (op_ctxt) | |||||
{ | |||||
case WhileContext cwhile: | |||||
{ | |||||
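// This Merge sits inside a while loop: guard the gradient with the pivot
// (the loop-condition Switch output) of the gradient while context.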
return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); | |||||
} | |||||
case CondContext ccond: | |||||
{ | |||||
var pred = ccond.pred; | |||||
if (grad_ctxt != null && grad_ctxt.grad_state != null) | |||||
{ | |||||
//# This Merge node is part of a cond within a loop. | |||||
//# The backprop needs to have the value of this predicate for every | |||||
//# iteration. So we must have its values accumulated in the forward, and | |||||
//# use the accumulated values as the predicate for this backprop switch. | |||||
var grad_state = grad_ctxt.grad_state; | |||||
var real_pred = grad_state.history_map[pred.name] as Tensor; | |||||
if (real_pred == null) | |||||
{ | |||||
//# Remember the value of pred for every iteration. | |||||
grad_ctxt = grad_state.grad_context; | |||||
grad_ctxt.Exit(); | |||||
var history_pred = grad_state.AddForwardAccumulator(pred); | |||||
grad_ctxt.Enter(); | |||||
//# Add the stack pop op. If pred.op is in a (outer) CondContext, | |||||
//# the stack pop will be guarded with a switch. | |||||
real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); | |||||
grad_state.history_map[pred.name] = real_pred; | |||||
} | |||||
pred = real_pred; | |||||
} | |||||
var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||||
return results; | |||||
} | |||||
default: | |||||
{ | |||||
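// Not in a control flow context: emit one gradient per Merge input.
// Input i only receives `grad` when the Merge's value_index output
// (op.outputs[1]) equals i; the other Switch branch stays dead.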
var num_inputs = op.inputs.Length; | |||||
var cond = new Tensor[num_inputs]; | |||||
for (int i = 0; i < num_inputs; i++) | |||||
cond[i] = math_ops.equal(op.outputs[1], i); | |||||
var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); | |||||
return result; | |||||
} | |||||
} | |||||
var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||||
return new Tensor[] { results.Item1, results.Item2 }; | |||||
} | } | ||||
} | |||||
public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
return _MergeGrad(op, grads); | |||||
} | |||||
/// <summary> | |||||
/// Gradients for an exit op are calculated using an Enter op. | |||||
/// </summary> | |||||
public Tensor[] _ExitGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
throw new NotImplementedException("_ExitGrad"); | |||||
// graph = ops.get_default_graph() | |||||
//# pylint: disable=protected-access | |||||
// op_ctxt = op._get_control_flow_context() | |||||
// grad_ctxt = graph._get_control_flow_context() | |||||
// # pylint: enable=protected-access | |||||
// if not grad_ctxt.back_prop: | |||||
// # The flag `back_prop` is set by users to suppress gradient | |||||
// # computation for this loop. If the attribute `back_prop` is false, | |||||
// # no gradient computation. | |||||
// return None | |||||
// if op_ctxt.grad_state: | |||||
// raise TypeError("Second-order gradient for while loops not supported.") | |||||
// if isinstance(grad, ops.Tensor) : | |||||
// grad_ctxt.AddName(grad.name) | |||||
// else: | |||||
// if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)): | |||||
// raise TypeError("Type %s not supported" % type(grad)) | |||||
// grad_ctxt.AddName(grad.values.name) | |||||
// grad_ctxt.AddName(grad.indices.name) | |||||
// dense_shape = grad.dense_shape | |||||
// if dense_shape is not None: | |||||
// grad_ctxt.AddName(dense_shape.name) | |||||
// grad_ctxt.Enter() | |||||
// # pylint: disable=protected-access | |||||
// result = control_flow_ops._Enter( | |||||
// grad, grad_ctxt.name, is_constant=False, | |||||
// parallel_iterations=grad_ctxt.parallel_iterations, | |||||
// name="b_exit") | |||||
// # pylint: enable=protected-access | |||||
// grad_ctxt.loop_enters.append(result) | |||||
// grad_ctxt.Exit() | |||||
// return result | |||||
} | |||||
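// Illustrative fragment (not part of this change): the first step of the Python
// above maps directly onto the back_prop virtual introduced on ControlFlowContext
// in this change - a loop built with back_prop: false simply produces no gradient.
//
//   var graph = ops.get_default_graph();
//   var grad_ctxt = graph._get_control_flow_context();
//   if (!grad_ctxt.back_prop)
//       return null;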
/// <summary> | |||||
/// A forward next_iteration is translated into a backprop identity. | |||||
/// | |||||
/// Note that the backprop next_iteration is added in switch grad. | |||||
/// </summary> | |||||
public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) | |||||
{ | |||||
return (_, grad); | |||||
} | |||||
public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) | |||||
{ | |||||
return (_, grad); | |||||
} | |||||
/// <summary> | |||||
/// Gradients for an Enter are calculated using an Exit op. | |||||
/// | |||||
/// For loop variables, grad is the gradient, so we just add an exit.
/// For loop invariants, we need to add an accumulator loop. | |||||
/// </summary> | |||||
public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad) | |||||
{ | |||||
throw new NotImplementedException("_EnterGrad"); | |||||
// graph = ops.get_default_graph() | |||||
//# pylint: disable=protected-access | |||||
// grad_ctxt = graph._get_control_flow_context() | |||||
// # pylint: enable=protected-access | |||||
// if not grad_ctxt.back_prop: | |||||
// # Skip gradient computation, if the attribute `back_prop` is false. | |||||
// return grad | |||||
// if grad_ctxt.grad_state is None: | |||||
// # Pass the gradient through if we are not in a gradient while context. | |||||
// return grad | |||||
// if op.get_attr("is_constant"): | |||||
// # Add a gradient accumulator for each loop invariant. | |||||
// if isinstance(grad, ops.Tensor) : | |||||
// result = grad_ctxt.AddBackpropAccumulator(op, grad) | |||||
// elif isinstance(grad, ops.IndexedSlices) : | |||||
// result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad) | |||||
// else: | |||||
// # TODO(yuanbyu, lukasr): Add support for SparseTensor. | |||||
// raise TypeError("Type %s not supported" % type(grad)) | |||||
// else: | |||||
// result = exit(grad) | |||||
// grad_ctxt.loop_exits.append(result) | |||||
// grad_ctxt.ExitResult([result]) | |||||
// return result | |||||
} | |||||
public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad) | |||||
{ | |||||
return _EnterGrad(op, grad); | |||||
} | |||||
/// <summary> | |||||
/// Stop backprop for the predicate of a while loop. | |||||
/// </summary> | |||||
public object _LoopCondGrad(object _) | |||||
{ | |||||
return null; | |||||
} | |||||
} | |||||
} | } |
@@ -3,13 +3,14 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class Graph | public partial class Graph | ||||
{ | { | ||||
// Current control flow context. It could be either CondContext or WhileContext | // Current control flow context. It could be either CondContext or WhileContext | ||||
public IControlFlowContext _control_flow_context; | |||||
public ControlFlowContext _control_flow_context; | |||||
// represents the nested with(...) statements | // represents the nested with(...) statements | ||||
public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); | public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); | ||||
@@ -97,7 +98,7 @@ namespace Tensorflow | |||||
/// Returns the current control flow context. | /// Returns the current control flow context. | ||||
/// </summary> | /// </summary> | ||||
/// <returns>A context object.</returns> | /// <returns>A context object.</returns> | ||||
public IControlFlowContext _get_control_flow_context() | |||||
public ControlFlowContext _get_control_flow_context() | |||||
{ | { | ||||
return _control_flow_context; | return _control_flow_context; | ||||
} | } | ||||
@@ -106,7 +107,7 @@ namespace Tensorflow | |||||
/// Sets the current control flow context. | /// Sets the current control flow context. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="ctx">a context object.</param> | /// <param name="ctx">a context object.</param> | ||||
public void _set_control_flow_context(IControlFlowContext ctx) | |||||
public void _set_control_flow_context(ControlFlowContext ctx) | |||||
{ | { | ||||
_control_flow_context = ctx; | _control_flow_context = ctx; | ||||
} | } | ||||
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -15,7 +16,7 @@ namespace Tensorflow | |||||
private List<ITensorOrOperation> _seen_nodes; | private List<ITensorOrOperation> _seen_nodes; | ||||
private List<_ControlDependenciesController> _old_stack; | private List<_ControlDependenciesController> _old_stack; | ||||
private bool _new_stack; | private bool _new_stack; | ||||
private IControlFlowContext _old_control_flow_context; | |||||
private ControlFlowContext _old_control_flow_context; | |||||
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | ||||
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations.ControlFlows; | |||||
namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
{ | { | ||||
@@ -107,8 +108,8 @@ namespace Tensorflow.Operations | |||||
with(ops.control_dependencies(null), ctrl => | with(ops.control_dependencies(null), ctrl => | ||||
{ | { | ||||
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||||
result = new[] { r0, r1 }[_branch]; | |||||
var results = control_flow_ops._SwitchRefOrTensor(result, _pred); | |||||
result = results[_branch]; | |||||
if (_outer_context != null) | if (_outer_context != null) | ||||
_outer_context.AddInnerOp(result.op); | _outer_context.AddInnerOp(result.op); | ||||
}); | }); | ||||
@@ -118,7 +119,7 @@ namespace Tensorflow.Operations | |||||
// Mark Switch output as seen by this context and any outer contexts, | // Mark Switch output as seen by this context and any outer contexts, | ||||
// just like what we do for normal op outputs in _AddOpInternal() below. | // just like what we do for normal op outputs in _AddOpInternal() below. | ||||
IControlFlowContext ctxt = this; | |||||
ControlFlowContext ctxt = this; | |||||
while (ctxt != null) | while (ctxt != null) | ||||
{ | { | ||||
ctxt.values.Add(result.name); | ctxt.values.Add(result.name); | ||||
@@ -223,8 +224,8 @@ namespace Tensorflow.Operations | |||||
_values.Add(real_val.name); | _values.Add(real_val.name); | ||||
_external_values[real_val.name] = real_val; | _external_values[real_val.name] = real_val; | ||||
} | } | ||||
var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||||
real_val = new[] {t0, t1}[_branch]; | |||||
var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||||
real_val = results[_branch]; | |||||
_external_values[val.name] = real_val; | _external_values[val.name] = real_val; | ||||
} | } | ||||
else | else | ||||
@@ -238,8 +239,8 @@ namespace Tensorflow.Operations | |||||
return real_val; | return real_val; | ||||
} | } | ||||
protected override void _AddOpInternal(Operation op) | |||||
{ | |||||
protected override void _AddOpInternal(Operation op) | |||||
{ | |||||
if (op.inputs.Length == 0) | if (op.inputs.Length == 0) | ||||
{ | { | ||||
//If we're in a while loop, remove any control inputs from outside the | //If we're in a while loop, remove any control inputs from outside the | ||||
@@ -282,11 +283,11 @@ namespace Tensorflow.Operations | |||||
// TODO: implement below code dependencies | // TODO: implement below code dependencies | ||||
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | ||||
// op._add_control_input(_pivot.op); | // op._add_control_input(_pivot.op); | ||||
} | |||||
// Mark op's outputs as seen by this context and any outer contexts. | |||||
} | |||||
// Mark op's outputs as seen by this context and any outer contexts. | |||||
var output_names = op.outputs.Select(x => x.name).ToArray(); | var output_names = op.outputs.Select(x => x.name).ToArray(); | ||||
IControlFlowContext ctxt = this; | |||||
ControlFlowContext ctxt = this; | |||||
while (ctxt != null) | while (ctxt != null) | ||||
{ | { | ||||
foreach (var name in output_names) | foreach (var name in output_names) | ||||
@@ -298,9 +299,31 @@ namespace Tensorflow.Operations | |||||
op.graph.prevent_fetching(op); | op.graph.prevent_fetching(op); | ||||
if (_outer_context != null) | if (_outer_context != null) | ||||
_outer_context.AddInnerOp(op); | |||||
} | |||||
_outer_context.AddInnerOp(op); | |||||
} | |||||
public override GradLoopState grad_state | |||||
{ | |||||
get | |||||
{ | |||||
var whc = GetWhileContext(); | |||||
if (whc != null) | |||||
return whc.grad_state; | |||||
return null; | |||||
} | |||||
} | |||||
public override bool back_prop | |||||
{ | |||||
get | |||||
{ | |||||
var whc = GetWhileContext(); | |||||
if (whc != null) | |||||
return whc.back_prop; | |||||
return false; | |||||
} | |||||
} | |||||
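// Note (not part of this change): these overrides are what allow _MergeGrad in
// control_flow_grad to test `grad_ctxt.grad_state != null` without caring whether
// the current context is a CondContext or a WhileContext; a cond nested inside a
// while loop reports the enclosing loop's grad state and back_prop flag.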
public CondContextDef to_proto(string export_scope) | public CondContextDef to_proto(string export_scope) | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations.ControlFlows; | |||||
namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
{ | { | ||||
@@ -22,21 +23,25 @@ namespace Tensorflow.Operations | |||||
/// 4. A ControlFlowContext has _context_stack. | /// 4. A ControlFlowContext has _context_stack. | ||||
/// Pushed and popped by ctxt.Enter() and ctxt.Exit() | /// Pushed and popped by ctxt.Enter() and ctxt.Exit() | ||||
/// </summary> | /// </summary> | ||||
public abstract class ControlFlowContext : Python, IPython, IControlFlowContext | |||||
public abstract class ControlFlowContext : Python, IPython | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// The predicate tensor in this branch | /// The predicate tensor in this branch | ||||
/// </summary> | /// </summary> | ||||
protected Tensor _pivot; | protected Tensor _pivot; | ||||
public Tensor pivot | |||||
{ | |||||
get => _pivot; | |||||
} | |||||
protected Stack<IControlFlowContext> _context_stack; | |||||
protected IControlFlowContext _outer_context; | |||||
protected Stack<ControlFlowContext> _context_stack; | |||||
protected ControlFlowContext _outer_context; | |||||
protected Dictionary<string, ITensorOrOperation> _external_values; | protected Dictionary<string, ITensorOrOperation> _external_values; | ||||
public ControlFlowContext() | public ControlFlowContext() | ||||
{ | { | ||||
_context_stack = new Stack<IControlFlowContext>(); | |||||
_context_stack = new Stack<ControlFlowContext>(); | |||||
} | } | ||||
public string name { get => _name; } | public string name { get => _name; } | ||||
@@ -111,8 +116,13 @@ namespace Tensorflow.Operations | |||||
_AddOpInternal(op); | _AddOpInternal(op); | ||||
} | } | ||||
public IControlFlowContext outer_context { get { return _outer_context; } } | |||||
public ControlFlowContext outer_context { get { return _outer_context; } } | |||||
public HashSet<string> values => _values; | public HashSet<string> values => _values; | ||||
public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method"); | |||||
public virtual bool back_prop => throw new NotImplementedException("abstract method"); | |||||
public virtual Tensor AddValue(Tensor val) | public virtual Tensor AddValue(Tensor val) | ||||
{ | { | ||||
// to be overridden | // to be overridden | ||||
@@ -147,7 +157,7 @@ namespace Tensorflow.Operations | |||||
/// <summary> | /// <summary> | ||||
/// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. | /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. | ||||
/// </summary> | /// </summary> | ||||
public static bool IsContainingContext(IControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||||
public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) | |||||
{ | { | ||||
while (ctxt != maybe_containing_ctxt) | while (ctxt != maybe_containing_ctxt) | ||||
{ | { | ||||
@@ -164,6 +174,16 @@ namespace Tensorflow.Operations | |||||
var internal_control_inputs = op.control_inputs; | var internal_control_inputs = op.control_inputs; | ||||
} | } | ||||
/// <summary> | |||||
/// Return the while context containing this context, or null if this context is not nested in a while loop.
/// </summary> | |||||
public virtual WhileContext GetWhileContext() | |||||
{ | |||||
if (_outer_context != null) | |||||
return _outer_context.GetWhileContext(); | |||||
return null; | |||||
} | |||||
public object to_proto() | public object to_proto() | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -173,5 +193,6 @@ namespace Tensorflow.Operations | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,277 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations.ControlFlows | |||||
{ | |||||
/// <summary> | |||||
/// Maintain the mapping from the loops to their grad states. | |||||
/// </summary> | |||||
public class ControlFlowState | |||||
{ | |||||
//class ControlFlowState(object): | |||||
// """Maintain the mapping from the loops to their grad states.""" | |||||
// def __init__(self): | |||||
// self._map = {} # maps forward loop context to GradLoopState | |||||
// def GetGradState(self, op, before): | |||||
// """Return the grad state for this op if it's in a forward loop context.""" | |||||
// if before and util.IsLoopExit(op): | |||||
// forward_ctxt = op._get_control_flow_context() | |||||
// forward_ctxt = forward_ctxt.outer_context | |||||
// if forward_ctxt: | |||||
// forward_ctxt = forward_ctxt.GetWhileContext() | |||||
// else: | |||||
// forward_ctxt = _GetWhileContext(op) | |||||
// if forward_ctxt: | |||||
// return self._map.get(forward_ctxt) | |||||
// return None | |||||
// def ProcessUnusedLoopExits(self, pending_count, to_ops_set): | |||||
// """Process all the "unused" loop exits. | |||||
// The "unused" exits of the loops are added to `unused_exits`. An exit is | |||||
// unused if its pending_count is 0. If there is an exit with real gradient, | |||||
// all these deferred exits will enter the backprop loop with zero gradient. | |||||
// Otherwise, they will enter the backprop loop with None. As an example, | |||||
// people often write: | |||||
// ```python | |||||
// v1, _ = tf.while_loop(p, b, [x1, x2]) | |||||
// result = gradients(v1, x1) | |||||
// ``` | |||||
// The exit node for x2 is not included by the betweenness analysis. But we | |||||
// need to backprop x2 if x2 is involved in computing v1. | |||||
// Args: | |||||
// pending_count: The number of backprop inputs for every op. | |||||
// to_ops_set: The set of ops for ys in gradients(ys, xs) | |||||
// Returns: | |||||
// The set of unused loop exits that we know at this point we need | |||||
// to backprop. | |||||
// """ | |||||
// loop_exits = [] | |||||
// for grad_state in self._map.values(): | |||||
// for y in grad_state.forward_loop_exits: | |||||
// if pending_count[y.op] == 0: | |||||
// grad_state.pending_exits_count -= 1 | |||||
// if y.op not in to_ops_set: | |||||
// grad_state.unused_exits.append(y) | |||||
// if grad_state.pending_exits_count == 0: | |||||
// loop_exits.extend(grad_state.unused_exits) | |||||
// # Need to include Enters in backprop for higher-order gradients. | |||||
// for y in grad_state.forward_context.loop_enters: | |||||
// if pending_count[y.op] == 0: | |||||
// pending_count[y.op] = 1 | |||||
// return loop_exits | |||||
// def EnterGradWhileContext(self, op, before): | |||||
// """Enter the WhileContext for gradient computation.""" | |||||
// grad_state = self.GetGradState(op, before) | |||||
// if grad_state: | |||||
// grad_state.grad_context.Enter() | |||||
// def ExitGradWhileContext(self, op, before): | |||||
// """Exit the WhileContext for gradient computation.""" | |||||
// grad_state = self.GetGradState(op, before) | |||||
// if grad_state: | |||||
// grad_state.grad_context.Exit() | |||||
// def AddWhileContext(self, op, between_op_list, between_ops): | |||||
// """Add the grad state for the while loop that op belongs to. | |||||
// Note that op is an Exit, and this method must be called in | |||||
// the control flow context where gradients() is called. | |||||
// Note that this method modifies `between_op_list` and `between_ops`. | |||||
// """ | |||||
// forward_ctxt = _GetWhileContext(op) | |||||
// grad_state = self._map.get(forward_ctxt) | |||||
// if grad_state is None: | |||||
// # This is a new while loop so create a grad state for it. | |||||
// outer_forward_ctxt = forward_ctxt.outer_context | |||||
// if outer_forward_ctxt: | |||||
// outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||||
// outer_grad_state = None | |||||
// if outer_forward_ctxt: | |||||
// outer_grad_state = self._map.get(outer_forward_ctxt) | |||||
// grad_state = GradLoopState(forward_ctxt, outer_grad_state) | |||||
// self._map[forward_ctxt] = grad_state | |||||
// # We need to include all exits of a loop for backprop. | |||||
// for loop_exit in grad_state.forward_loop_exits: | |||||
// if loop_exit.op not in between_ops: | |||||
// between_ops.add(loop_exit.op) | |||||
// between_op_list.append(loop_exit.op) | |||||
// def ZerosLikeForExit(self, val): | |||||
// """Create zeros_like gradient for a loop exit. | |||||
// If the result of a loop variable is not used but is involved in | |||||
// computing the result of some needed loop variable, we create a | |||||
// zero-valued tensor that is fed as gradient for the Exit node of that | |||||
// loop variable. Note that val.op is an Exit, and this method must be | |||||
// called in the control flow context where gradients() is called. | |||||
// Args: | |||||
// val: The output tensor of an Exit op. | |||||
// Returns: | |||||
// A zero tensor of the same shape of val. | |||||
// """ | |||||
// val_shape = val.get_shape() | |||||
// forward_ctxt = val.op._get_control_flow_context() | |||||
// outer_forward_ctxt = forward_ctxt.outer_context | |||||
// if outer_forward_ctxt: | |||||
// outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() | |||||
// outer_grad_state = None | |||||
// if outer_forward_ctxt: | |||||
// outer_grad_state = self._map.get(outer_forward_ctxt) | |||||
// if outer_grad_state: | |||||
// # This is a nested loop. | |||||
// if val_shape.is_fully_defined(): | |||||
// # If the shape is known statically, just create a zero tensor | |||||
// # with the right shape in the right context. | |||||
// outer_grad_state.grad_context.Enter() | |||||
// result = array_ops.zeros(val_shape.dims, val.dtype) | |||||
// outer_grad_state.grad_context.Exit() | |||||
// else: | |||||
// # Only the shape of value is needed for backprop. | |||||
// forward_ctxt.outer_context.Enter() | |||||
// shape = array_ops.shape_internal(val, optimize=False) | |||||
// forward_ctxt.outer_context.Exit() | |||||
// # Save the shape to a stack. | |||||
// history_shape = outer_grad_state.AddForwardAccumulator(shape) | |||||
// # Get the shape back from the stack. | |||||
// outer_grad_ctxt = outer_grad_state.grad_context | |||||
// outer_grad_ctxt.Enter() | |||||
// real_shape = outer_grad_state.AddBackpropAccumulatedValue( | |||||
// history_shape, shape) | |||||
// result = array_ops.zeros(real_shape, val.dtype) | |||||
// outer_grad_ctxt.Exit() | |||||
// else: | |||||
// # This is not a nested loop. | |||||
// if val_shape.is_fully_defined(): | |||||
// # If the shape is known statically, just create a zero tensor | |||||
// # with the right shape. | |||||
// result = array_ops.zeros(val_shape.dims, val.dtype) | |||||
// else: | |||||
// result = array_ops.zeros_like(val, optimize=False) | |||||
// return result | |||||
// def ZerosLike(self, op, index): | |||||
// """Create zeros_like for the specified output of an op. | |||||
// If op is in a while loop that is part of gradients(), this method | |||||
// must be called in its grad loop context. | |||||
// Args: | |||||
// op: A tensorflow operation. | |||||
// index: the index for a specific output of the op. | |||||
// Returns: | |||||
// A zero tensor of the same shape of op.outputs[index]. | |||||
// """ | |||||
// if util.IsLoopSwitch(op): | |||||
// return None | |||||
// if op.graph._building_function: # pylint: disable=protected-access | |||||
// # The optimization here is tricky to apply to functions | |||||
// return array_ops.zeros_like(op.outputs[index]) | |||||
// dead_branch = util.IsSwitch(op) | |||||
// forward_ctxt = _GetWhileContext(op) | |||||
// grad_state = self._map.get(forward_ctxt) | |||||
// if grad_state is None: | |||||
// # op is not in a while loop that is part of gradients(). | |||||
// return ZerosLikeOutsideLoop(op, index) | |||||
// op_ctxt = op._get_control_flow_context() | |||||
// val = ops.convert_to_tensor(op.outputs[index], name="tensor") | |||||
// shape = val.get_shape() | |||||
// if shape.is_fully_defined(): | |||||
// # If the shape is known statically, just create a zero tensor with | |||||
// # the right shape in the grad loop context. | |||||
// result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) | |||||
// if dead_branch: | |||||
// # op is a cond switch. Guard the zero tensor with a switch. | |||||
// pred = grad_state.history_map.get(op_ctxt.pred.name) | |||||
// branch = op_ctxt.branch | |||||
// result = _SwitchRefOrTensor(result, pred)[1 - branch] | |||||
// else: | |||||
// # Unknown shape so keep a history of the shape at runtime. | |||||
// if dead_branch: | |||||
// # Need to add a special switch to guard the value. | |||||
// pred = op_ctxt.pred | |||||
// branch = op_ctxt.branch | |||||
// op_ctxt.outer_context.Enter() | |||||
// val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch] | |||||
// zeros_shape = array_ops.shape_internal(val, optimize=False) | |||||
// op_ctxt.outer_context.Exit() | |||||
// val.op._set_control_flow_context(op_ctxt) | |||||
// zeros_shape.op._set_control_flow_context(op_ctxt) | |||||
// else: | |||||
// op_ctxt.Enter() | |||||
// zeros_shape = array_ops.shape_internal(val, optimize=False) | |||||
// op_ctxt.Exit() | |||||
// # Add forward accumulator for shape. | |||||
// grad_state.grad_context.Exit() | |||||
// history_zeros_shape = grad_state.AddForwardAccumulator( | |||||
// zeros_shape, dead_branch=dead_branch) | |||||
// grad_state.grad_context.Enter() | |||||
// # Create a zero tensor with the right shape. | |||||
// shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, | |||||
// zeros_shape, dead_branch) | |||||
// result = array_ops.zeros(shape, val.dtype) | |||||
// return result | |||||
// def PostProcessing(self): | |||||
// """Perform postprocessing at the end of gradients(). | |||||
// We have created the gradient graph at this point. So this function | |||||
// can be used to perform any postprocessing on the gradient graph. | |||||
// We currently perform the following postprocessing: | |||||
// 1. Patch the gradient graph if the output of a loop variable | |||||
// doesn't depend on its input. | |||||
// """ | |||||
// for _, grad_state in self._map.items(): | |||||
// for _, b_merge in grad_state.switch_map.items(): | |||||
// if b_merge.op.inputs[0] == b_merge.op.inputs[1]: | |||||
// # The value of this loop variable at iteration i+1 doesn't | |||||
// # depend on its value at iteration i. So use zeros as the | |||||
// # gradients for all iterations > 0. | |||||
// dtype = b_merge.op.inputs[0].dtype | |||||
// shape = b_merge.op.inputs[0].get_shape() | |||||
// # pylint: disable=protected-access | |||||
// if shape.is_fully_defined(): | |||||
// grad_state.grad_context.Enter() | |||||
// # Create a zeros and use it for iterations > 0. | |||||
// grad_val = constant_op.constant(0, dtype=dtype, shape=shape) | |||||
// next_grad_val = _NextIteration(grad_val) | |||||
// grad_state.grad_context.Exit() | |||||
// else: | |||||
// # Create a zeros in the outer grad context. | |||||
// outer_grad_ctxt = grad_state.grad_context.outer_context | |||||
// if outer_grad_ctxt: | |||||
// outer_grad_ctxt.Enter() | |||||
// enter_grad_op = b_merge.op.inputs[0].op | |||||
// enter_grad = enter_grad_op.inputs[0] | |||||
// grad_shape = array_ops.shape_internal(enter_grad, optimize=False) | |||||
// grad_val = array_ops.zeros(grad_shape) | |||||
// if outer_grad_ctxt: | |||||
// outer_grad_ctxt.Exit() | |||||
// # Use the zeros for iterations > 0. | |||||
// grad_state.grad_context.Enter() | |||||
// next_grad_val = _NextIteration(grad_val) | |||||
// grad_state.grad_context.Exit() | |||||
// b_merge.op._update_input(1, next_grad_val) | |||||
// # pylint: enable=protected-access | |||||
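// Illustrative sketch (not part of this change): per the commented Python above,
// the core of this class is a map from a forward WhileContext to its GradLoopState.
// The lookup below uses GetOutputContext/GetWhileContext from this change and, to
// stay minimal, skips the Python's special handling of loop-exit ops.
private Dictionary<WhileContext, GradLoopState> _map = new Dictionary<WhileContext, GradLoopState>();

public GradLoopState GetGradState(Operation op)
{
    // Find the forward while loop (if any) that produced op's output.
    var forward_ctxt = control_flow_util.GetOutputContext(op)?.GetWhileContext();
    if (forward_ctxt != null && _map.ContainsKey(forward_ctxt))
        return _map[forward_ctxt];
    return null;
}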
} | |||||
} |
@@ -0,0 +1,398 @@ | |||||
using System; | |||||
using System.Collections; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations.ControlFlows | |||||
{ | |||||
public class GradLoopState | |||||
{ | |||||
//class GradLoopState(object): | |||||
// """The state used for constructing the gradient graph for a while loop. | |||||
// We create a GradLoopState for each while loop in forward and its | |||||
// corresponding while loop in backprop. This gives us access to both | |||||
// the forward and the backprop WhileContexts. | |||||
// During the construction of gradient graph, any time when we detect | |||||
// a forward value that is needed for backprop, we create a history | |||||
// accumulator and add it to `history_map`. Any time when we backprop | |||||
// a loop switch op (in _SwitchGrad), we add the grad merge op in | |||||
// `switch_map`. | |||||
// """ | |||||
// def __init__(self, forward_ctxt, outer_grad_state): | |||||
// # The grad loop state for the outer while loop. | |||||
// self._outer_grad_state = None | |||||
// # The while loop context for forward. | |||||
// self._forward_context = None | |||||
// # The loop counter added by AddForwardLoopCounter. It is the value | |||||
// # of the loop counter for the next iteration. | |||||
// self._forward_index = None | |||||
// # A sync op for forward. | |||||
// self._forward_sync = None | |||||
// # The while loop context for backprop. | |||||
private WhileContext _grad_context = null; | |||||
public WhileContext grad_context => _grad_context; | |||||
// # The loop counter added by AddBackpropLoopCounter. It is the value | |||||
// # of the loop counter for the current iteration. | |||||
// self._grad_index = None | |||||
// # A sync op for backprop. | |||||
// self._grad_sync = None | |||||
// # Information needed by backprop. | |||||
private Hashtable _history_map = new Hashtable(); | |||||
public Hashtable history_map => _history_map; | |||||
private Hashtable _switch_map = new Hashtable(); | |||||
public Hashtable switch_map => _switch_map; | |||||
// self._unused_exits = [] | |||||
// self._deferred_exits = [] | |||||
// self._forward_loop_exits = list(forward_ctxt.loop_exits) | |||||
// self._pending_exits_count = len(forward_ctxt.loop_exits) | |||||
// self._outer_grad_state = outer_grad_state | |||||
// if outer_grad_state: | |||||
// outer_forward_ctxt = outer_grad_state.forward_context | |||||
// else: | |||||
// if not hasattr(forward_ctxt, "outer_context"): | |||||
// raise ValueError("Failed to call gradients on a while loop without" | |||||
// "properly serializing graph via MetaGraphDef") | |||||
// outer_forward_ctxt = forward_ctxt.outer_context | |||||
// # Add the forward loop counter. | |||||
// with forward_ctxt._graph.as_default(): # pylint: disable=protected-access | |||||
// if outer_forward_ctxt: | |||||
// outer_forward_ctxt.Enter() | |||||
// cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) | |||||
// if outer_forward_ctxt: | |||||
// outer_forward_ctxt.Exit() | |||||
// self._forward_context = forward_ctxt | |||||
// self._forward_index = forward_index | |||||
// # Add the backprop WhileContext, and the backprop loop counter. | |||||
// if outer_grad_state: | |||||
// # This is a nested loop. Remember the iteration counts for each | |||||
// # execution of this inner loop. | |||||
// outer_forward_ctxt.AddName(cnt.name) | |||||
// history_cnt = outer_grad_state.AddForwardAccumulator(cnt) | |||||
// outer_grad_ctxt = outer_grad_state.grad_context | |||||
// outer_grad_ctxt.Enter() | |||||
// self._grad_context = WhileContext( | |||||
// maximum_iterations=forward_ctxt.maximum_iterations, | |||||
// parallel_iterations=forward_ctxt.parallel_iterations, | |||||
// back_prop=forward_ctxt.back_prop, | |||||
// swap_memory=forward_ctxt.swap_memory, | |||||
// name=forward_ctxt.name, | |||||
// grad_state=self) | |||||
// real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) | |||||
// self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||||
// real_cnt, outer_grad_state) | |||||
// outer_grad_ctxt.Exit() | |||||
// else: | |||||
// if outer_forward_ctxt: | |||||
// outer_forward_ctxt.Enter() | |||||
// self._grad_context = WhileContext( | |||||
// maximum_iterations=forward_ctxt.maximum_iterations, | |||||
// parallel_iterations=forward_ctxt.parallel_iterations, | |||||
// back_prop=forward_ctxt.back_prop, | |||||
// swap_memory=forward_ctxt.swap_memory, | |||||
// name=forward_ctxt.name, | |||||
// grad_state=self) | |||||
// self._grad_index = self._grad_context.AddBackpropLoopCounter( | |||||
// cnt, outer_grad_state) | |||||
// if outer_forward_ctxt: | |||||
// outer_forward_ctxt.Exit() | |||||
// @property | |||||
// def outer_grad_state(self): | |||||
// """The grad loop state for outer loop.""" | |||||
// return self._outer_grad_state | |||||
// @property | |||||
// def forward_context(self): | |||||
// """The while loop context for forward.""" | |||||
// return self._forward_context | |||||
// @property | |||||
// def forward_index(self): | |||||
// """The loop index of forward loop.""" | |||||
// return self._forward_index | |||||
// @property | |||||
// def forward_sync(self): | |||||
// """A control trigger node for synchronization in the forward loop. | |||||
// One main use is to keep the push ops of a stack executed in the | |||||
// iteration order. | |||||
// """ | |||||
// if self._forward_sync is None: | |||||
// with ops.control_dependencies(None): | |||||
// self._forward_sync = control_trigger(name="f_sync") | |||||
// self._forward_sync._set_control_flow_context(self._forward_context) | |||||
// self._forward_index.op._add_control_input(self._forward_sync) | |||||
// return self._forward_sync | |||||
// @property | |||||
// def grad_context(self): | |||||
// """The corresponding WhileContext for gradient.""" | |||||
// return self._grad_context | |||||
// @property | |||||
// def grad_index(self): | |||||
// """The loop index of backprop loop.""" | |||||
// return self._grad_index | |||||
// @property | |||||
// def grad_sync(self): | |||||
// """A control trigger node for synchronization in the grad loop. | |||||
// One main use is to keep the pop ops of a stack executed in the | |||||
// iteration order. | |||||
// """ | |||||
// if self._grad_sync is None: | |||||
// with ops.control_dependencies(None): | |||||
// self._grad_sync = control_trigger(name="b_sync") | |||||
// self._grad_sync._set_control_flow_context(self._grad_context) | |||||
// self._grad_index.op._add_control_input(self._grad_sync) | |||||
// if self._grad_context.outer_context: | |||||
// self._grad_context.outer_context.AddInnerOp(self._grad_sync) | |||||
// return self._grad_sync | |||||
// @property | |||||
// def history_map(self): | |||||
// """The map that records all the tensors needed for backprop.""" | |||||
// return self._history_map | |||||
// @property | |||||
// def switch_map(self): | |||||
// """The map that records all the Switch ops for the while loop.""" | |||||
// return self._switch_map | |||||
// @property | |||||
// def unused_exits(self): | |||||
// """The list of "unused" exits.""" | |||||
// return self._unused_exits | |||||
// @property | |||||
// def deferred_exits(self): | |||||
// """The list of "deferred" exits.""" | |||||
// return self._deferred_exits | |||||
// @property | |||||
// def forward_loop_exits(self): | |||||
// """The list of exits of the forward loop.""" | |||||
// return self._forward_loop_exits | |||||
// @property | |||||
// def pending_exits_count(self): | |||||
// """The number of exits we expect to see but haven't.""" | |||||
// return self._pending_exits_count | |||||
// @pending_exits_count.setter | |||||
// def pending_exits_count(self, cnt): | |||||
// """Set the pending count to cnt.""" | |||||
// self._pending_exits_count = cnt | |||||
/// <summary> | |||||
/// Add an accumulator for each forward tensor that is needed in backprop. | |||||
/// | |||||
/// This is added to the forward loop the first time a tensor in the forward
/// loop is used by the backprop gradient computation loop. We create an
/// accumulator that accumulates the value of the tensor at each
/// iteration. Called in the control flow context where gradients() is called. | |||||
/// | |||||
/// The pseudocode is: | |||||
/// ``` | |||||
/// acc = stack(); | |||||
/// while (_pivot) { | |||||
/// acc = stack_push(acc, value); | |||||
/// } | |||||
/// ``` | |||||
/// | |||||
/// We make sure that the stack push op in one iteration is executed before the
/// next iteration. This is achieved by adding a control edge from | |||||
/// `forward_index.op.inputs[0].op` to the push op, and another control | |||||
/// edge from the push op to either `forward_index.op` or `forward_sync`. | |||||
/// </summary> | |||||
/// <param name="value"> The source tensor in forward that is to be accumulated.</param> | |||||
/// <param name="dead_branch"> True iff the tensor is on a dead branch of a cond.</param> | |||||
/// <returns>The stack that contains the accumulated history of the tensor.</returns> | |||||
public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) | |||||
{ | |||||
throw new NotImplementedException("AddForwardAccumulator"); | |||||
// # curr_ctxt is the context that tf.gradients was called in. | |||||
// with self._forward_index.graph.as_default(): | |||||
// curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access | |||||
// with ops.control_dependencies(None): | |||||
// if curr_ctxt: | |||||
// curr_ctxt.Enter() | |||||
// with ops.colocate_with(value): | |||||
// # We only need to pass maximum_iterations to the stack if | |||||
// # we're inside an XLA context. | |||||
// if not util.IsInXLAContext(value.op): | |||||
// max_size = constant_op.constant(-1, dtypes.int32) | |||||
// else: | |||||
// max_size = GetMaxSizeFromNestedMaximumIterations( | |||||
// value, self.forward_context) | |||||
// acc = gen_data_flow_ops.stack_v2( | |||||
// max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") | |||||
// if curr_ctxt: | |||||
// curr_ctxt.Exit() | |||||
// # Make acc available in the forward context. | |||||
// enter_acc = self.forward_context.AddValue(acc) | |||||
// # Add the stack_push op in the context of value.op. | |||||
// swap_enabled = self.forward_context.swap_memory | |||||
// value_ctxt = util.GetOutputContext(value.op) | |||||
// if value_ctxt == self.forward_context: | |||||
// # value is not nested in the forward context. | |||||
// self.forward_context.Enter() | |||||
// push = gen_data_flow_ops.stack_push_v2( | |||||
// enter_acc, value, swap_memory=swap_enabled) | |||||
// self.forward_context.Exit() | |||||
// # Protect stack push and order it before forward_index. | |||||
// self.forward_index.op._add_control_input(push.op) | |||||
// else: | |||||
// # value is in a cond context within the forward context. | |||||
// if not isinstance(value_ctxt, CondContext): | |||||
// raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) | |||||
// if dead_branch: | |||||
// # The special case for creating a zero tensor for a dead | |||||
// # branch of a switch. See ControlFlowState.ZerosLike(). | |||||
// value_ctxt.outer_context.Enter() | |||||
// push = gen_data_flow_ops.stack_push_v2( | |||||
// enter_acc, value, swap_memory=swap_enabled) | |||||
// value_ctxt.outer_context.Exit() | |||||
// push.op._set_control_flow_context(value_ctxt) | |||||
// else: | |||||
// value_ctxt.Enter() | |||||
// push = gen_data_flow_ops.stack_push_v2( | |||||
// enter_acc, value, swap_memory=swap_enabled) | |||||
// value_ctxt.Exit() | |||||
// # Protect stack push and order it before forward_sync. | |||||
// self.forward_sync._add_control_input(push.op) | |||||
// # Order stack push after the successor of forward_index | |||||
// add_op = self.forward_index.op.inputs[0].op | |||||
// push.op._add_control_input(add_op) | |||||
// return acc | |||||
} | |||||
// """Add the getter for an accumulated value in the grad context. | |||||
// | |||||
// This is added to the backprop loop. Called in the grad context to | |||||
// get the value of an accumulated value. The stack pop op must be guarded | |||||
// by the pred of the controlling cond. | |||||
// | |||||
// Args: | |||||
// history_value: The history (a stack) of a value. | |||||
// value: The value that is pushed onto the stack. | |||||
// dead_branch: True iff the tensor is on a dead branch of a cond. | |||||
// | |||||
// Returns: | |||||
// The current value (the top of the stack). | |||||
// """ | |||||
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
{ | |||||
throw new NotImplementedException(); | |||||
// history_ctxt = history_value.op._get_control_flow_context() | |||||
// # Find the cond context that controls history_value if any. | |||||
// cond_ctxt = None | |||||
// value_ctxt = value.op._get_control_flow_context() | |||||
// while value_ctxt and value_ctxt != history_ctxt: | |||||
// if isinstance(value_ctxt, CondContext): | |||||
// cond_ctxt = value_ctxt | |||||
// break | |||||
// value_ctxt = value_ctxt.outer_context | |||||
// with ops.control_dependencies(None): | |||||
// self.grad_context.Enter() | |||||
// if cond_ctxt: | |||||
// # Guard stack pop with a switch if it is controlled by a cond. | |||||
// grad_state = self | |||||
// pred = None | |||||
// while pred is None and grad_state: | |||||
// pred = grad_state.history_map.get(cond_ctxt.pred.name) | |||||
// grad_state = grad_state.outer_grad_state | |||||
// if pred is None: | |||||
// pred = cond_ctxt.pred | |||||
// branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch | |||||
// history_value = _SwitchRefOrTensor(history_value, pred)[branch] | |||||
// pop = gen_data_flow_ops.stack_pop_v2(history_value, | |||||
// value.dtype.base_dtype) | |||||
// pop.set_shape(value.get_shape()) | |||||
// self.grad_context.Exit() | |||||
// parallel_iterations = self.grad_context.parallel_iterations | |||||
// if parallel_iterations > 1: | |||||
// # All pops are ordered after pivot_for_body and before grad_sync. | |||||
// self.grad_sync._add_control_input(pop.op) | |||||
// return pop | |||||
} | |||||
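// Usage sketch (taken from _MergeGrad in control_flow_grad earlier in this change):
// the forward accumulator and the backprop pop always come in pairs - push the
// value during the forward pass, pop it when the gradient loop needs it.
//
//   grad_ctxt.Exit();
//   var history_pred = grad_state.AddForwardAccumulator(pred);
//   grad_ctxt.Enter();
//   real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred);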
// def GetRealValue(self, value): | |||||
// """Get the real value of `value`. | |||||
// If backprop "uses" a value produced by forward inference, an accumulator | |||||
// is added in the forward loop to accumulate its values. We use the | |||||
// accumulated value. This method must be called in the grad loop context. | |||||
// `value` must be in forward and needed for backprop. | |||||
// Args: | |||||
// value: A tensor to be captured. | |||||
// Returns: | |||||
// The same tensor obtained from the saved history. | |||||
// """ | |||||
// assert value.op.type not in ["Variable", "VariableV2"] | |||||
// real_value = self._history_map.get(value.name) | |||||
// if real_value is None: | |||||
// cur_value = value | |||||
// cur_grad_state = self | |||||
// while True: | |||||
// enter_op = util.GetLoopConstantEnter(cur_value) | |||||
// if enter_op: | |||||
// # Special case: cur_value comes from a constant Enter node. | |||||
// cur_value = enter_op.inputs[0] | |||||
// cur_grad_state = cur_grad_state.outer_grad_state | |||||
// if cur_grad_state is None: | |||||
// # We are now outside all nested loops for this gradient(), | |||||
// # so `value` is a loop invariant and there is no need to | |||||
// # save the history of value. Just make cur_value to enter | |||||
// # the right control flow context. | |||||
// real_value = self._grad_context.AddValue(cur_value) | |||||
// break | |||||
// elif constant_op.is_constant(cur_value): | |||||
// # If the value to be forwarded is a constant, clone the constant in | |||||
// # the gradient loop rather than using a stack. | |||||
// # TODO(phawkins): consider hoisting the constant out of the loop | |||||
// # instead. | |||||
// real_value = constant_op.constant( | |||||
// tensor_util.constant_value(cur_value), dtype=cur_value.dtype) | |||||
// break | |||||
// else: | |||||
// # Record the history of this value in forward_ctxt. | |||||
// self._grad_context.Exit() | |||||
// history_value = cur_grad_state.AddForwardAccumulator(cur_value) | |||||
// self._grad_context.Enter() | |||||
// break | |||||
// if real_value is None: | |||||
// # Add the stack pop op in the grad context. | |||||
// real_value = cur_grad_state.AddBackpropAccumulatedValue( | |||||
// history_value, cur_value) | |||||
// if cur_grad_state != self: | |||||
// real_value = self._grad_context.AddValue(real_value) | |||||
// self._history_map[value.name] = real_value | |||||
// return real_value | |||||
} | |||||
} |
@@ -4,13 +4,15 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public interface IControlFlowContext | |||||
{ | |||||
void AddOp(Operation op); | |||||
IControlFlowContext outer_context { get; } | |||||
HashSet<string> values { get; } | |||||
Tensor AddValue(Tensor val); | |||||
void AddInnerOp(Operation resultOp); | |||||
object to_proto(); | |||||
} | |||||
// henon: this was too much trouble. There is no value, just cost, in using an interface here.
//public interface IControlFlowContext | |||||
//{ | |||||
// void AddOp(Operation op); | |||||
// IControlFlowContext outer_context { get; } | |||||
// HashSet<string> values { get; } | |||||
// Tensor pivot { get; } | |||||
// Tensor AddValue(Tensor val); | |||||
// void AddInnerOp(Operation resultOp); | |||||
// object to_proto(); | |||||
//} | |||||
} | } |
@@ -1,11 +1,26 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations.ControlFlows; | |||||
namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
{ | { | ||||
public class WhileContext : ControlFlowContext | public class WhileContext : ControlFlowContext | ||||
{ | { | ||||
private bool _back_prop = true;
private GradLoopState _grad_state = null;
public override WhileContext GetWhileContext() | |||||
{ | |||||
return this; | |||||
} | |||||
public override GradLoopState grad_state => _grad_state; | |||||
public override bool back_prop => _back_prop; | |||||
public static WhileContext from_proto(object proto) | public static WhileContext from_proto(object proto) | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -7,7 +7,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class Operation | public partial class Operation | ||||
{ | { | ||||
private IControlFlowContext _control_flow_context; | |||||
private ControlFlowContext _control_flow_context; | |||||
/// <summary> | /// <summary> | ||||
/// Add this op to its control flow context. | /// Add this op to its control flow context. | ||||
@@ -39,12 +39,12 @@ namespace Tensorflow | |||||
_add_control_input(op); | _add_control_input(op); | ||||
} | } | ||||
public void _set_control_flow_context(IControlFlowContext ctx) | |||||
public void _set_control_flow_context(ControlFlowContext ctx) | |||||
{ | { | ||||
_control_flow_context = ctx; | _control_flow_context = ctx; | ||||
} | } | ||||
public IControlFlowContext _get_control_flow_context() | |||||
public ControlFlowContext _get_control_flow_context() | |||||
{ | { | ||||
return _control_flow_context; | return _control_flow_context; | ||||
} | } | ||||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Tensorflow.Operations.ControlFlows; | |||||
using util = Tensorflow.control_flow_util; | using util = Tensorflow.control_flow_util; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -93,9 +94,9 @@ namespace Tensorflow | |||||
/// <param name="between_op_list"></param> | /// <param name="between_op_list"></param> | ||||
/// <param name="between_ops"></param> | /// <param name="between_ops"></param> | ||||
/// <param name="colocate_gradients_with_ops"></param> | /// <param name="colocate_gradients_with_ops"></param> | ||||
public static object MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||||
public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops) | |||||
{ | { | ||||
object loop_state = null; | |||||
ControlFlowState loop_state = null; | |||||
foreach (var op in between_op_list) | foreach (var op in between_op_list) | ||||
{ | { | ||||
@@ -103,7 +104,7 @@ namespace Tensorflow | |||||
{ | { | ||||
if(loop_state == null) | if(loop_state == null) | ||||
{ | { | ||||
// loop_state = ControlFlowState(); | |||||
loop_state = new ControlFlowState(); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -207,7 +208,7 @@ namespace Tensorflow | |||||
/// `(output_false, output_true)`: If `pred` is true, data will be forwarded to | /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to | ||||
/// `output_true`, otherwise it goes to `output_false`. | /// `output_true`, otherwise it goes to `output_false`. | ||||
/// </returns> | /// </returns> | ||||
public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||||
public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||||
{ | { | ||||
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | ||||
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | ||||
@@ -298,7 +299,9 @@ namespace Tensorflow | |||||
*/ | */ | ||||
// Add the Switch to the graph. | // Add the Switch to the graph. | ||||
var (p_2, p_1) = @switch(pred, pred); | |||||
var switch_result = @switch(pred, pred);
var p_2 = switch_result[0];
var p_1 = switch_result[1]; | |||||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | ||||
var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | ||||
pred = array_ops.identity(pred, name: "pred_id"); | pred = array_ops.identity(pred, name: "pred_id"); | ||||
@@ -379,7 +382,9 @@ namespace Tensorflow | |||||
return with(ops.name_scope(name, "cond", new { pred }), delegate | return with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
{ | { | ||||
// Add the Switch to the graph. | // Add the Switch to the graph. | ||||
var (p_2, p_1) = @switch(pred, pred); | |||||
var switch_result = @switch(pred, pred); | |||||
var p_2 = switch_result[0]; | |||||
var p_1 = switch_result[1]; | |||||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | ||||
var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | ||||
pred = array_ops.identity(pred, name: "pred_id"); | pred = array_ops.identity(pred, name: "pred_id"); | ||||
@@ -460,7 +465,7 @@ namespace Tensorflow | |||||
/// <param name="pred"></param> | /// <param name="pred"></param> | ||||
/// <param name="dtype"></param> | /// <param name="dtype"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
public static (Tensor, Tensor) @switch(Tensor data, | |||||
public static Tensor[] @switch(Tensor data, | |||||
Tensor pred, | Tensor pred, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
string name = null) | string name = null) | ||||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Return the control flow context for the output of an op. | /// Return the control flow context for the output of an op. | ||||
/// </summary> | /// </summary> | ||||
public static IControlFlowContext GetOutputContext(Operation op) | |||||
public static ControlFlowContext GetOutputContext(Operation op) | |||||
{ | { | ||||
var ctxt = op._get_control_flow_context(); | var ctxt = op._get_control_flow_context(); | ||||
// Exit nodes usually have a control flow context, except in the case where the | // Exit nodes usually have a control flow context, except in the case where the | ||||
@@ -33,14 +33,14 @@ namespace Tensorflow | |||||
/// output_false: A `Tensor`. Has the same type as `data`. | /// output_false: A `Tensor`. Has the same type as `data`. | ||||
/// output_true: A `Tensor`. Has the same type as `data`. | /// output_true: A `Tensor`. Has the same type as `data`. | ||||
/// </returns> | /// </returns> | ||||
public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null) | |||||
public static Tensor[] @switch(Tensor data, Tensor pred, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred }); | ||||
var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
var _attrs = ("T", _op.get_attr("T")); | var _attrs = ("T", _op.get_attr("T")); | ||||
// TODO: missing original code | // TODO: missing original code | ||||
//_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); | //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); | ||||
return (_op.outputs[0], _op.outputs[1]); | |||||
return new[] { _op.outputs[0], _op.outputs[1] };
} | } | ||||
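// Illustrative note (not part of this change): callers now index the returned
// array instead of deconstructing a tuple. The ordering is unchanged - element 0
// is output_false and element 1 is output_true - which is why the CondContext
// changes above select `results[_branch]`.
//
//   var outputs = @switch(data, pred);
//   var output_false = outputs[0];
//   var output_true = outputs[1];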
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | ||||