From 59cbca5c1703ed5fe607ffa11cb8e54edbe9fdbb Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Nov 2019 10:41:10 -0600 Subject: [PATCH] ControlFlowState.PostProcessing --- .../ControlFlows/ControlFlowState.cs | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs index 1d296774..9351cab4 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs @@ -290,7 +290,35 @@ namespace Tensorflow.Operations.ControlFlows public void PostProcessing() { - throw new NotImplementedException("PostProcessing"); + foreach(var grad_state in _map.Values) + { + foreach(var b_merge in grad_state.switch_map.Values) + { + if(b_merge.op.inputs[0] == b_merge.op.inputs[1]) + { + Tensor next_grad_val = null; + // The value of this loop variable at iteration i+1 doesn't + // depend on its value at iteration i. So use zeros as the + // gradients for all iterations > 0. + var dtype = b_merge.op.inputs[0].dtype; + var shape = b_merge.op.inputs[0].TensorShape; + if (shape.is_fully_defined()) + { + grad_state.grad_context.Enter(); + // Create a zeros and use it for iterations > 0. + var grad_val = constant_op.constant(0, dtype: dtype, shape: shape); + next_grad_val = control_flow_ops._NextIteration(grad_val); + grad_state.grad_context.Exit(); + } + else + { + throw new NotImplementedException("PostProcessing shape is not fully defined."); + } + + b_merge.op._update_input(1, next_grad_val); + } + } + } } } }