You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FuncGraph.cs 9.3 kB

5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. using Google.Protobuf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Eager;
  6. using Tensorflow.Exceptions;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow.Graphs
  9. {
  10. /// <summary>
  11. /// Graph representing a function body.
  12. /// </summary>
  13. public class FuncGraph : Graph
  14. {
  15. IntPtr _func_graph_handle;
  16. public string FuncName => _graph_key;
  17. public Tensors Inputs { get; set; } = new Tensors();
  18. public Tensors Outputs { get; set; } = new Tensors();
  19. public Dictionary<string, string> Attrs { get; set; }
  20. Dictionary<long, (Tensor, Tensor)> _captures
  21. = new Dictionary<long, (Tensor, Tensor)>();
  22. public Tensor[] external_captures
  23. => _captures.Select(x => x.Value.Item1).ToArray();
  24. public (Tensor, Tensor)[] captures
  25. => _captures.Values.Select(x => x).ToArray();
  26. public Tensor[] internal_captures
  27. => _captures.Select(x => x.Value.Item2).ToArray();
  28. public Tensor[] captured_inputs
  29. => external_captures;
  30. /// <summary>
  31. /// Construct a new FuncGraph.
  32. /// </summary>
  33. public FuncGraph(string name) : base()
  34. {
  35. outer_graph = ops.get_default_graph();
  36. while (outer_graph.building_function)
  37. outer_graph = outer_graph.OuterGraph;
  38. _graph_key = name;
  39. building_function = true;
  40. }
  41. public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
  42. {
  43. outer_graph = ops.get_default_graph();
  44. while (outer_graph.building_function)
  45. outer_graph = outer_graph.OuterGraph;
  46. _graph_key = name;
  47. building_function = true;
  48. Attrs = attrs;
  49. // Will to test if FuncGraph has memory leak
  50. // c_api.TF_DeleteGraph(_handle);
  51. _handle = handle;
  52. }
  53. public void ToGraph(Operation[] opers,
  54. Tensor[] inputs, Tensor[] outputs,
  55. string[] output_names)
  56. {
  57. var status = new Status();
  58. _func_graph_handle = c_api.TF_GraphToFunction(_handle,
  59. _graph_key,
  60. false,
  61. opers.Length,
  62. opers.Select(x => (IntPtr)x).ToArray(),
  63. inputs.Length,
  64. inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
  65. outputs.Length,
  66. outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
  67. output_names == null || output_names.Length == 0 ? null : output_names,
  68. IntPtr.Zero,
  69. null,
  70. status.Handle);
  71. status.Check(true);
  72. SetAttrs();
  73. // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle);
  74. // status.Check(true);
  75. c_api.TFE_ContextAddFunction(tf.Context.Handle, _func_graph_handle, status.Handle);
  76. status.Check(true);
  77. _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle));
  78. Inputs = inputs;
  79. // mark_as_return
  80. Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
  81. }
  82. public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true)
  83. {
  84. foreach(var (i, inp) in enumerate(inputs))
  85. inputs[i] = capture(inp);
  86. return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
  87. }
  88. const int _EAGER_CONST_THRESHOLD = 128;
  89. public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
  90. {
  91. if(tensor is EagerTensor)
  92. {
  93. if (name == null)
  94. name = ops.uid().ToString();
  95. // Small EagerTensors are captured with Const ops
  96. if (dtypes.is_value_dtype(tensor.dtype)
  97. && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
  98. return capture_eager_tensor(tensor, name);
  99. // Large EagerTensors and resources are captured with Placeholder ops
  100. return _capture_helper(tensor, name, shape: shape);
  101. }
  102. if(tensor.graph != this)
  103. {
  104. if (name == null)
  105. name = tensor.op.name;
  106. var inner_graph = tensor.graph;
  107. while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
  108. {
  109. if (inner_graph == this)
  110. throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
  111. " in another function or code block. Use return values," +
  112. " explicit Python locals or TensorFlow collections to access" +
  113. $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
  114. inner_graph = inner_func_graph.outer_graph;
  115. }
  116. return _capture_helper(tensor, name);
  117. }
  118. return tensor;
  119. }
  120. Tensor capture_eager_tensor(Tensor tensor, string name)
  121. {
  122. Tensor graph_const = null;
  123. if (!_captures.ContainsKey(tensor.Id))
  124. {
  125. graph_const = tf_with(ops.control_dependencies(null), ctl
  126. => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
  127. add_capture(tensor, graph_const);
  128. }
  129. else
  130. {
  131. graph_const = _captures[tensor.Id].Item2;
  132. }
  133. BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
  134. {
  135. return output_grads;
  136. };
  137. tf.Runner.RecordGradient("captured_value",
  138. new[] { graph_const }, null,
  139. new[] { tensor },
  140. getBackwardFunction: () => _backward_function_wrapper
  141. /*getForwardFunction: forward_function*/);
  142. return graph_const;
  143. }
  144. Tensor _capture_helper(Tensor tensor, string name, Shape shape = null)
  145. {
  146. Tensor placeholder = null;
  147. if (!_captures.ContainsKey(tensor.Id))
  148. {
  149. placeholder = _create_substitute_placeholder(tensor,
  150. name: name,
  151. dtype: tensor.dtype,
  152. shape: shape);
  153. add_capture(tensor, placeholder);
  154. }
  155. else
  156. {
  157. placeholder = _captures[tensor.Id].Item2;
  158. }
  159. BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
  160. {
  161. return output_grads;
  162. };
  163. tf.Runner.RecordGradient("captured_value",
  164. new[] { placeholder }, null,
  165. new[] { tensor },
  166. getBackwardFunction: () => _backward_function_wrapper
  167. /*getForwardFunction: forward_function*/);
  168. return placeholder;
  169. }
  170. void add_capture(Tensor tensor, Tensor placeholder)
  171. {
  172. _captures.Add(tensor.Id, (tensor, placeholder));
  173. Inputs.Add(placeholder);
  174. }
  175. Tensor _create_substitute_placeholder(Tensor value,
  176. string name = null,
  177. TF_DataType dtype = TF_DataType.DtInvalid,
  178. Shape shape = null)
  179. {
  180. if (shape is null)
  181. shape = value.shape;
  182. if (dtype == TF_DataType.DtInvalid)
  183. dtype = value.dtype;
  184. var placeholder = tf_with(ops.control_dependencies(null), ctl
  185. => array_ops.placeholder(dtype, shape: shape, name: name));
  186. // custom_gradient.copy_handle_data(value, placeholder)
  187. return placeholder;
  188. }
  189. void SetAttrs()
  190. {
  191. if (Attrs == null)
  192. return;
  193. foreach (var (_name, attr_value) in enumerate(Attrs))
  194. {
  195. var serialized = new AttrValue
  196. {
  197. S = ByteString.CopyFromUtf8(attr_value)
  198. }.ToByteArray();
  199. c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status.Handle);
  200. tf.Status.Check(true);
  201. }
  202. }
  203. public override Graph as_default()
  204. {
  205. tf.Context.graph_mode(isFunc: true);
  206. ops.set_default_graph(this);
  207. return this;
  208. }
  209. public override void Exit()
  210. {
  211. tf.Context.restore_mode();
  212. ops.pop_graph();
  213. }
  214. protected override void DisposeUnmanagedResources(IntPtr handle)
  215. {
  216. c_api.TFE_ContextRemoveFunction(tf.Context.Handle, _graph_key, tf.Status.Handle);
  217. c_api.TF_DeleteFunction(_func_graph_handle);
  218. base.DisposeUnmanagedResources(handle);
  219. }
  220. }
  221. }