using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;

namespace Tensorflow.Functions
{
    /// <summary>
    /// Caches forward and backward functions compatible with eager gradients.
    /// </summary>
    public abstract class TapeGradientFunctions
    {
        // Attribute keys written onto the generated functions so each one names its counterpart.
        const string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
        const string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
        // Name prefixes for generated functions.
        const string _FORWARD_PREFIX = "__forward_";
        const string _BACKWARD_PREFIX = "__backward_";
        const string _INFERENCE_PREFIX = "__inference_";

        protected FuncGraph _func_graph;
        protected EagerDefinedFunction _forward;
        protected FuncGraph _forward_graph;
        protected List<int> _forwardprop_output_indices;
        protected int _num_forwardprop_outputs;
        protected ConcreteFunction _backward;

        /// <summary>
        /// Creates the cache for <paramref name="func_graph"/>.
        /// </summary>
        /// <param name="func_graph">Function graph the forward/backward pair is built from.</param>
        /// <param name="need_gradients_for_jvps">Accepted for API parity; not used by this base class.</param>
        public TapeGradientFunctions(FuncGraph func_graph,
            bool need_gradients_for_jvps)
        {
            _func_graph = func_graph;
        }

        /// <summary>
        /// Returns the forward function for the given inference arguments.
        /// Delegates to <see cref="ForwardAndBackwardFunctions"/>.
        /// </summary>
        /// <param name="inference_args">Arguments of the inference call.</param>
        /// <returns>The forward function.</returns>
        public EagerDefinedFunction Forward(Tensors inference_args)
        {
            return ForwardAndBackwardFunctions(inference_args);
        }

        /// <summary>
        /// Record the function call operation.
        /// Registers the call of <c>_forward</c> with <c>tf.Runner.RecordGradient</c>,
        /// supplying a wrapper around <c>_backward</c> as the gradient function.
        /// </summary>
        /// <param name="flat_outputs">Outputs produced by the forward call.</param>
        /// <param name="inference_args">Inputs the forward call was made with.</param>
        public void Record(Tensors flat_outputs, Tensors inference_args)
        {
            var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
            tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
                getBackwardFunction: () => backward_function);
        }

        /// <summary>
        /// Create a backward function given `outputs` from the forward function.
        /// </summary>
        /// <param name="forward_graph">Graph of the forward function; its <c>Outputs</c> are matched positionally with <paramref name="outputs"/>.</param>
        /// <param name="backward">Backward function whose captured inputs are rebound to this call's outputs.</param>
        /// <param name="outputs">Concrete outputs of this forward call.</param>
        /// <returns>
        /// The wrapped backward function, and the outputs to record on the tape. An output is
        /// recorded while fewer than (backward non-captured input count) trainable outputs have
        /// been seen.
        /// </returns>
        (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
        {
            // Map each symbolic forward-graph output (by tensor id) to the concrete value this call
            // produced. Like Python's zip(), pairing stops at the shorter of the two sequences.
            var capture_mapping = new Dictionary<long, Tensor>();
            var forward_output_count = forward_graph.Outputs.Length;
            foreach (var (i, output) in enumerate(outputs))
            {
                if (i >= forward_output_count)
                    break;
                capture_mapping[forward_graph.Outputs[i].Id] = output;
            }

            // Substitute this call's values for captured forward-graph tensors. Captures that are
            // not forward outputs are passed through unchanged, so the list stays positionally
            // aligned with backward.CapturedInputs (dropping them would shift later captures).
            var remapped_captures = new Tensors();
            foreach (var capture in backward.CapturedInputs)
                remapped_captures.Add(capture_mapping.TryGetValue(capture.Id, out var mapped) ? mapped : capture);

            // Number of backward-function inputs that are incoming gradients (not captures).
            var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
            var recorded_outputs = new Tensors();
            var trainable_recorded_outputs = 0;
            // Positions of non-trainable outputs; their incoming gradients are not forwarded.
            var skip_positions = new HashSet<int>();
            foreach (var (output_index, output) in enumerate(outputs))
            {
                if (trainable_recorded_outputs < backward_function_inputs)
                    recorded_outputs.Add(output);
                if (gradients_util.IsTrainable(output))
                    trainable_recorded_outputs += 1;
                else
                    skip_positions.Add(output_index);
            }

            BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
            {
                // Keep only gradients for trainable outputs, up to the backward function's arity.
                var processed_args = new Tensors();
                var input_index = 0;
                foreach (var (output_index, arg) in enumerate(args))
                {
                    if (skip_positions.Contains(output_index))
                        continue;
                    if (arg == null)
                        throw new NotImplementedException(
                            $"Received no gradient for output {output_index} of backward function {backward.Name}; " +
                            "substituting zeros for missing gradients is not implemented.");
                    processed_args.Add(arg);
                    input_index += 1;
                    if (input_index >= backward_function_inputs)
                        break;
                }

                tf.Logger.Debug($"Invoke backward function: {backward.Name}");
                var gradients = backward.CallFlat(processed_args, remapped_captures);

                // Insert a null placeholder for each unneeded-gradient index that lies at or
                // beyond the current end of the gradient list.
                foreach (var unneeded_gradient_index in unneeded_gradients)
                {
                    var index = Convert.ToInt32(unneeded_gradient_index);
                    if (gradients.Length <= index)
                        gradients.Insert(index, null);
                }

                return gradients;
            };

            return (_backward_function_wrapper, recorded_outputs);
        }

        /// <summary>
        /// Builds the forward and backward functions for the trainable subset of <paramref name="outputs"/>.
        /// Forward-graph tensors captured by the backward graph are appended to
        /// <c>_func_graph.Outputs</c> (if not already present) so the forward call produces them.
        /// </summary>
        /// <param name="outputs">Outputs of <c>_func_graph</c> to differentiate.</param>
        /// <param name="inference_args">Not used by this implementation.</param>
        /// <returns>
        /// (forward function, forward graph (<c>_func_graph</c>), backward function,
        /// forwardprop output indices (always null here), forwardprop output count (always 0)).
        /// </returns>
        protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
            BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
        {
            var trainable_outputs = new List<Tensor>();
            var trainable_indices = new List<int>();
            foreach (var (index, output) in enumerate(outputs))
            {
                if (gradients_util.IsTrainable(output))
                {
                    trainable_outputs.Add(output);
                    trainable_indices.Add(index);
                }
            }

            var gradients_wrt_outputs = new List<Tensor>();
            var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}");
            Tensors gradients_wrt_inputs;
            backwards_graph.as_default();
            try
            {
                // One placeholder per trainable output receives that output's incoming gradient.
                foreach (var output in trainable_outputs)
                    gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
                gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
                    _func_graph.Inputs,
                    grad_ys: gradients_wrt_outputs.ToArray(),
                    src_graph: _func_graph);

                var captures_from_forward = backwards_graph.external_captures
                    .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
                    .ToArray();
                foreach (var capture in captures_from_forward)
                {
                    if (!_func_graph.Outputs.Contains(capture))
                        _func_graph.Outputs.Add(capture);
                }
            }
            finally
            {
                // Always leave the backward graph, even if gradient construction throws;
                // otherwise it would remain the default graph.
                backwards_graph.Exit();
            }

            var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
            var backward_function_attr = new Dictionary<string, string>();
            backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
            gradients_wrt_outputs.append(backwards_graph.internal_captures);
            backwards_graph.Inputs = gradients_wrt_outputs;
            backwards_graph.Outputs = gradients_wrt_inputs;

            var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);

            var forward_function_attr = new Dictionary<string, string>();
            forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
            var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
                _func_graph.Inputs, _func_graph.Outputs, forward_function_attr);

            return (forward_function, _func_graph, backward_function, null, 0);
        }

        /// <summary>
        /// Builds (or returns cached) forward/backward functions for <paramref name="inference_args"/>.
        /// The base implementation always throws; subclasses are expected to override it.
        /// </summary>
        /// <exception cref="NotImplementedException">Always, in this base class.</exception>
        public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
        {
            throw new NotImplementedException(
                $"{nameof(ForwardAndBackwardFunctions)} must be overridden by a subclass of {nameof(TapeGradientFunctions)}.");
        }
    }
}