|
|
@@ -290,7 +290,35 @@ namespace Tensorflow.Operations.ControlFlows |
|
|
|
|
|
|
|
public void PostProcessing()
|
|
|
|
{
|
|
|
|
throw new NotImplementedException("PostProcessing");
|
|
|
|
foreach(var grad_state in _map.Values)
|
|
|
|
{
|
|
|
|
foreach(var b_merge in grad_state.switch_map.Values)
|
|
|
|
{
|
|
|
|
if(b_merge.op.inputs[0] == b_merge.op.inputs[1])
|
|
|
|
{
|
|
|
|
Tensor next_grad_val = null;
|
|
|
|
// The value of this loop variable at iteration i+1 doesn't
|
|
|
|
// depend on its value at iteration i. So use zeros as the
|
|
|
|
// gradients for all iterations > 0.
|
|
|
|
var dtype = b_merge.op.inputs[0].dtype;
|
|
|
|
var shape = b_merge.op.inputs[0].TensorShape;
|
|
|
|
if (shape.is_fully_defined())
|
|
|
|
{
|
|
|
|
grad_state.grad_context.Enter();
|
|
|
|
// Create a zeros and use it for iterations > 0.
|
|
|
|
var grad_val = constant_op.constant(0, dtype: dtype, shape: shape);
|
|
|
|
next_grad_val = control_flow_ops._NextIteration(grad_val);
|
|
|
|
grad_state.grad_context.Exit();
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
throw new NotImplementedException("PostProcessing shape is not fully defined.");
|
|
|
|
}
|
|
|
|
|
|
|
|
b_merge.op._update_input(1, next_grad_val);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|