You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

EagerDefinedFunction.cs 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. using Google.Protobuf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow.Contexts;
  8. using Tensorflow.Eager;
  9. using Tensorflow.Graphs;
  10. using Tensorflow.Operations;
  11. using Tensorflow.Util;
  12. using Tensorflow.Common.Extensions;
  13. using static Tensorflow.Binding;
  14. using Tensorflow.Framework;
  15. using System.Buffers;
  16. using Tensorflow.Gradients;
  17. namespace Tensorflow.Functions
  18. {
  /// <summary>
  /// A function built from a <see cref="FuncGraph"/> via <c>TF_GraphToFunction</c> and registered on the
  /// current <c>tf.Context</c>, so it can be invoked eagerly (<c>execute.executes</c>) or from graph mode
  /// (<c>functional_ops.partitioned_call</c>).
  /// </summary>
  public class EagerDefinedFunction: IDisposable
  {
      /// <summary>Number of outputs declared by the function signature.</summary>
      public int _num_outputs;
      FuncGraph _graph;
      FunctionDef _definition;
      OpDef _signature;
      string _name;
      internal ScopedTFFunction _c_func;
      internal Tensor[] _func_graph_outputs;
      internal string _grad_func_name;
      internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func;
      internal EagerDefinedFunction _grad_func;
      // True while this function is registered on tf.Context; Dispose clears it.
      internal bool _registered_on_context = false;

      /// <summary>The function name as reported by the generated signature.</summary>
      public string Name => _name;
      public DataType[] OutputTypes { get; protected set; }
      public Shape[] OutputShapes { get; protected set; }

      /// <summary>The FunctionDef proto of the underlying TF_Function (parsed lazily if not cached).</summary>
      public FunctionDef Definition
      {
          get
          {
              if (_definition is null)
              {
                  _definition = _get_definition();
              }
              return _definition;
          }
      }

      /// <summary>The OpDef signature taken from <see cref="Definition"/>.</summary>
      public OpDef Signature
      {
          get
          {
              if (_signature is null)
              {
                  _signature = Definition.Signature;
              }
              return _signature;
          }
      }

      /// <summary>
      /// Converts <paramref name="graph"/> into a TF_Function, applies <paramref name="attrs"/> to it and
      /// registers it on the current context.
      /// </summary>
      /// <param name="name">Requested function name (the final name is read back from the signature).</param>
      /// <param name="graph">Graph holding the function body.</param>
      /// <param name="inputs">Tensors used as function inputs; their producing ops are excluded from the body.</param>
      /// <param name="outputs">Tensors returned by the function.</param>
      /// <param name="attrs">Attributes set on the function; may be null.</param>
      public unsafe EagerDefinedFunction(string name, FuncGraph graph,
          Tensors inputs, Tensors outputs,
          Dictionary<string, AttrValue> attrs)
      {
          // Every graph operation except the ones producing the inputs becomes part of the body.
          var input_ops = inputs.Select(x => x.op).ToArray();
          var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
              .Select(x => x as Operation).ToArray();

          string[] output_names = _get_output_names(graph, outputs);

          Status status = new Status();
          var fn = c_api.TF_GraphToFunction(graph.c_graph,
              name,
              false,
              operations.Length,
              operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(),
              inputs.Length,
              inputs.Select(t => t._as_tf_output()).ToArray(),
              outputs.Length,
              outputs.Select(t => t._as_tf_output()).ToArray(),
              // Explicit names are passed only when every output has one; otherwise null.
              output_names.Length != outputs.Length ? null : output_names,
              IntPtr.Zero, // warning: the control outputs have been totally ignored.
              null,
              status);
          status.Check(true);
          _c_func = new ScopedTFFunction(fn, name);

          if (attrs is not null)
          {
              foreach (var (attr_name, attr_value) in attrs)
              {
                  var serialized = attr_value.ToByteArray();
                  c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status);
                  status.Check(true);
              }
          }

          // Cache the parsed definition (attrs are already applied) so Definition/Signature
          // do not parse the same FunctionDef a second time.
          _definition = _get_definition();
          var signature = _definition.Signature;
          _name = signature.Name;

          tf_with(ops.init_scope(), s =>
          {
              tf.Context.add_function(fn);
              _registered_on_context = true;
          });

          _num_outputs = signature.OutputArg.Count;
          OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray();
          OutputShapes = outputs.Select(x => x.shape).ToArray();
          _func_graph_outputs = new List<Tensor>(outputs).ToArray();
          csharp_grad_func = null;
          _graph = graph;
      }

      /// <summary>
      /// Picks the names recorded in <c>graph._output_names</c> for <paramref name="outputs"/>.
      /// Returns an empty array when the graph has no names, some output is unnamed, or names collide.
      /// </summary>
      private static string[] _get_output_names(FuncGraph graph, Tensors outputs)
      {
          var graph_output_names = graph._output_names;
          if (graph_output_names is null
              || !outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t))))
          {
              return new string[0];
          }

          var names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray();
          return names.Distinct().Count() == names.Length ? names : new string[0];
      }

      /// <summary>
      /// Calls the function on <paramref name="args"/>: directly through the eager executor when
      /// executing eagerly, otherwise by building a partitioned-call op.
      /// </summary>
      /// <returns>The function outputs; in graph mode their shapes are set from <see cref="OutputShapes"/>.</returns>
      public unsafe Tensors Call(Tensors args)
      {
          // TODO(Rinne): Add arg `CancellationManager`.
          // TODO(Rinne): Check the arg length.
          var function_call_options = tf.Context.FunctionCallOptions;
          string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons.
          //if (function_call_options.config_proto_serialized().Length == 0)
          //{
          //    config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
          //}
          //else
          //{
          //    config = function_call_options.config_proto_serialized().ToStringUtf8();
          //}
          string executor_type = function_call_options.ExecutorType ?? "";
          var executing_eagerly = tf.Context.executing_eagerly();
          var attrs = new object[]
          {
              "executor_type", executor_type,
              "config_proto", config
          };

          Tensor[] outputs;
          if (executing_eagerly)
          {
              outputs = execute.executes(
                  Signature.Name,
                  _num_outputs,
                  args,
                  attrs,
                  tf.Context);
          }
          else
          {
              var tape_set = tf.GetTapeSet();
              if (tape_set.Count == 0)
              {
                  outputs = functional_ops.partitioned_call(args, this, OutputTypes,
                      executing_eagerly, config, "");
              }
              else
              {
                  // Pause the innermost tape while the call op is built. The finally block makes sure
                  // recording resumes even if building the call throws; previously the tape stayed stopped.
                  var tape = tape_set.Peek();
                  tape.StopRecord();
                  try
                  {
                      outputs = functional_ops.partitioned_call(args, this, OutputTypes,
                          executing_eagerly, config, "");
                  }
                  finally
                  {
                      tape.StartRecord();
                  }
              }
          }

          foreach (var (i, func_graph_output) in enumerate(_func_graph_outputs))
          {
              handle_data_util.copy_handle_data(func_graph_output, outputs[i]);
          }

          if (!executing_eagerly)
          {
              // In graph mode, stamp the shapes captured at construction onto the call outputs.
              foreach (var (i, shape) in enumerate(OutputShapes))
              {
                  outputs[i].shape = shape;
              }
          }
          return outputs;
      }

      /// <summary>
      /// Adds this function (and the functions defined in its graph) to <paramref name="g"/>, or to the
      /// current eager context when <paramref name="g"/> is null and execution is eager.
      /// </summary>
      /// <param name="g">Target graph. When null outside eager mode, the default graph is used
      /// (previously this dereferenced null).</param>
      public void AddToGraph(Graph g = null)
      {
          if (g is null && tf.Context.executing_eagerly())
          {
              var ctx = tf.Context;
              if (!ctx.has_function(Name))
              {
                  ctx.add_function_def(Definition);
              }
              return;
          }

          g ??= ops.get_default_graph();
          if (!g.IsFunction(Name))
          {
              g.AddFunction(this);
          }
          foreach (var f in _graph.Functions.Values)
          {
              if (!g.IsFunction(f.Name))
              {
                  g.AddFunction(f);
              }
          }
      }

      /// <summary>Serializes the underlying TF_Function into a <see cref="FunctionDef"/> proto.</summary>
      private FunctionDef _get_definition()
      {
          var buffer = c_api_util.tf_buffer();
          Status status = new();
          c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status);
          status.Check(true);
          var proto_data = c_api.TF_GetBuffer(buffer);
          return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>());
      }

      /// <summary>
      /// Removes the function from the context it was registered on. Safe to call more than once:
      /// only a registered function is removed, and the registration flag is cleared afterwards.
      /// </summary>
      public void Dispose()
      {
          if (!_registered_on_context)
          {
              return;
          }
          tf.Context.remove_function(Name);
          _registered_on_context = false;
      }
  }
  217. }