You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

TapeGradientFunctions.cs 7.5 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Graphs;
  6. using static Tensorflow.Binding;
  7. using static Tensorflow.tensorflow;
  8. namespace Tensorflow.Functions
  9. {
  10. /// <summary>
  11. /// Caches forward and backward functions compatible with eager gradients.
  12. /// </summary>
  13. public abstract class TapeGradientFunctions
  14. {
  15. string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
  16. string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
  17. string _FORWARD_PREFIX = "__forward_";
  18. string _BACKWARD_PREFIX = "__backward_";
  19. string _INFERENCE_PREFIX = "__inference_";
  20. protected FuncGraph _func_graph;
  21. protected EagerDefinedFunction _forward;
  22. protected FuncGraph _forward_graph;
  23. protected List<int> _forwardprop_output_indices;
  24. protected int _num_forwardprop_outputs;
  25. protected ConcreteFunction _backward;
  26. public TapeGradientFunctions(FuncGraph func_graph,
  27. bool need_gradients_for_jvps)
  28. {
  29. _func_graph = func_graph;
  30. }
  31. public EagerDefinedFunction Forward(Tensors inference_args)
  32. {
  33. return ForwardAndBackwardFunctions(inference_args);
  34. }
  35. /// <summary>
  36. /// Record the function call operation.
  37. /// </summary>
  38. /// <param name="flat_outputs"></param>
  39. /// <param name="inference_args"></param>
  40. public void Record(Tensors flat_outputs, Tensors inference_args)
  41. {
  42. var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
  43. tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
  44. getBackwardFunction: () => backward_function);
  45. }
  46. /// <summary>
  47. /// Create a backward function given `outputs` from the forward function.
  48. /// </summary>
  49. /// <param name="forward_graph"></param>
  50. /// <param name="backward"></param>
  51. /// <param name="outputs"></param>
  52. /// <returns></returns>
  53. (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
  54. {
  55. var capture_mapping = new Dictionary<long, Tensor>();
  56. foreach(var (i, output) in enumerate(outputs))
  57. capture_mapping[forward_graph.Outputs[i].Id] = output;
  58. var remapped_captures = new Tensors();
  59. foreach(var capture in backward.CapturedInputs)
  60. {
  61. if (capture_mapping.ContainsKey(capture.Id))
  62. remapped_captures.Add(capture_mapping[capture.Id]);
  63. }
  64. var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
  65. var recorded_outputs = new Tensors();
  66. var relevant_outputs = outputs;
  67. var trainable_recorded_outputs = 0;
  68. var skip_positions = new List<int>();
  69. foreach (var (output_index, output) in enumerate(relevant_outputs))
  70. {
  71. if (trainable_recorded_outputs < backward_function_inputs)
  72. recorded_outputs.Add(output);
  73. if (gradients_util.IsTrainable(output))
  74. trainable_recorded_outputs += 1;
  75. else
  76. skip_positions.Add(output_index);
  77. }
  78. BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
  79. {
  80. var processed_args = new Tensors();
  81. var input_index = 0;
  82. foreach (var (output_index, arg) in enumerate(args))
  83. {
  84. if (skip_positions.Contains(output_index))
  85. continue;
  86. if (arg == null)
  87. throw new NotImplementedException("");
  88. processed_args.Add(arg);
  89. input_index += 1;
  90. if (input_index >= backward_function_inputs)
  91. break;
  92. }
  93. tf.Logger.Debug($"Invoke backward function: {backward.Name}");
  94. var gradients = backward.CallFlat(processed_args, remapped_captures);
  95. foreach (var unneeded_gradient_index in unneeded_gradients)
  96. {
  97. var index = Convert.ToInt32(unneeded_gradient_index);
  98. if (gradients.Length <= index)
  99. gradients.Insert(index, null);
  100. }
  101. return gradients;
  102. };
  103. return (_backward_function_wrapper, recorded_outputs);
  104. }
  105. protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
  106. BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
  107. {
  108. var trainable_outputs = new List<Tensor>();
  109. var trainable_indices = new List<int>();
  110. foreach(var (index, output) in enumerate(outputs))
  111. {
  112. if (gradients_util.IsTrainable(output))
  113. {
  114. trainable_outputs.Add(output);
  115. trainable_indices.Add(index);
  116. }
  117. }
  118. var gradients_wrt_outputs = new List<Tensor>();
  119. var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}");
  120. backwards_graph.as_default();
  121. foreach (var output in trainable_outputs)
  122. gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
  123. var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
  124. _func_graph.Inputs,
  125. grad_ys: gradients_wrt_outputs.ToArray(),
  126. src_graph: _func_graph);
  127. var captures_from_forward = backwards_graph.external_captures
  128. .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
  129. .ToArray();
  130. foreach(var capture in captures_from_forward)
  131. {
  132. if (!_func_graph.Outputs.Contains(capture))
  133. _func_graph.Outputs.Add(capture);
  134. }
  135. backwards_graph.Exit();
  136. var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
  137. var backward_function_attr = new Dictionary<string, string>();
  138. backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
  139. gradients_wrt_outputs.append(backwards_graph.internal_captures);
  140. backwards_graph.Inputs = gradients_wrt_outputs;
  141. backwards_graph.Outputs = gradients_wrt_inputs;
  142. var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
  143. var forward_function_attr = new Dictionary<string, string>();
  144. forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
  145. var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
  146. _func_graph.Inputs, _func_graph.Outputs, forward_function_attr);
  147. return (forward_function, _func_graph, backward_function, null, 0);
  148. }
  149. public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
  150. {
  151. throw new NotImplementedException("");
  152. }
  153. }
  154. }