Browse Source

ControlFlowState.PostProcessing

tags/v0.12
Oceania2018 6 years ago
parent
commit
59cbca5c17
1 changed files with 29 additions and 1 deletions
  1. +29
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs

+ 29
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs View File

@@ -290,7 +290,35 @@ namespace Tensorflow.Operations.ControlFlows
public void PostProcessing()
{
throw new NotImplementedException("PostProcessing");
foreach(var grad_state in _map.Values)
{
foreach(var b_merge in grad_state.switch_map.Values)
{
if(b_merge.op.inputs[0] == b_merge.op.inputs[1])
{
Tensor next_grad_val = null;
// The value of this loop variable at iteration i+1 doesn't
// depend on its value at iteration i. So use zeros as the
// gradients for all iterations > 0.
var dtype = b_merge.op.inputs[0].dtype;
var shape = b_merge.op.inputs[0].TensorShape;
if (shape.is_fully_defined())
{
grad_state.grad_context.Enter();
// Create a zeros and use it for iterations > 0.
var grad_val = constant_op.constant(0, dtype: dtype, shape: shape);
next_grad_val = control_flow_ops._NextIteration(grad_val);
grad_state.grad_context.Exit();
}
else
{
throw new NotImplementedException("PostProcessing shape is not fully defined.");
}
b_merge.op._update_input(1, next_grad_val);
}
}
}
}
}
}

Loading…
Cancel
Save