You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConcreteFunction.cs 5.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Framework.Models;
  5. using Tensorflow.Graphs;
  6. using static Tensorflow.Binding;
  7. namespace Tensorflow.Functions
  8. {
  /// <summary>
  /// A callable wrapper around a <see cref="FuncGraph"/>. The graph is either supplied
  /// directly or traced by running a delegate on placeholder inputs, and is invoked via
  /// <see cref="CallFlat"/> or <see cref="FilteredCall"/>.
  /// </summary>
  public class ConcreteFunction
  {
      FuncGraph func_graph;

      /// <summary>The wrapped graph's input tensors.</summary>
      public Tensor[] Inputs => func_graph.Inputs;

      /// <summary>External tensors captured by the graph; <see cref="FilteredCall"/> appends them after the explicit inputs.</summary>
      public Tensor[] CapturedInputs => func_graph.external_captures;

      /// <summary>Function name of the wrapped graph, or null when no graph is attached.</summary>
      public string Name => func_graph?.FuncName;

      /// <summary>Outputs produced while tracing; set only by the <c>Func&lt;Tensors, Tensors&gt;</c> constructor.</summary>
      public Tensor[] Outputs;

      /// <summary>Not assigned anywhere in this class.</summary>
      public Type ReturnType;

      /// <summary>
      /// Specs of the traced outputs. Set by <see cref="ToGraph"/> and by every tracing
      /// constructor except the <c>Func&lt;Tensor, Tensor&gt;</c> one.
      /// </summary>
      public TensorSpec[] OutputStructure;

      /// <summary>Creates a function around a new, empty <see cref="FuncGraph"/>.</summary>
      /// <param name="name">Name given to the new graph.</param>
      public ConcreteFunction(string name)
      {
          func_graph = new FuncGraph(name);
      }

      /// <summary>Wraps an existing graph and finalizes it with its non-null outputs.</summary>
      /// <param name="graph">Graph to wrap.</param>
      /// <param name="attrs">Currently unused.</param>
      public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
      {
          func_graph = graph;
          ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
      }

      /// <summary>Traces a single-tensor function on one placeholder of <paramref name="dtype"/>.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
      public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
      {
          if (func is null)
              throw new ArgumentNullException(nameof(func));

          string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
          func_graph = new FuncGraph(func_name);
          func_graph.as_default();
          // Exit() in finally: a throwing delegate must not skip it and leave
          // func_graph active after as_default() above.
          try
          {
              var input = tf.placeholder(dtype);
              var output = func(input);
              func_graph.ToGraph(CollectOperations(),
                  new[] { input },
                  new[] { output },
                  null);
          }
          finally
          {
              func_graph.Exit();
          }
      }

      /// <summary>
      /// Traces a dataset-producing function on one placeholder of <paramref name="dtype"/>;
      /// the dataset's variant tensor becomes the graph output.
      /// </summary>
      /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
      public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
      {
          if (func is null)
              throw new ArgumentNullException(nameof(func));

          string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
          func_graph = new FuncGraph(func_name);
          func_graph.as_default();
          // Exit() in finally so a failure while tracing cannot leave func_graph active.
          try
          {
              var input = tf.placeholder(dtype);
              var output = func(input);
              OutputStructure = output.structure;
              func_graph.ToGraph(CollectOperations(),
                  new[] { input },
                  new[] { output.variant_tensor },
                  null);
          }
          finally
          {
              func_graph.Exit();
          }
      }

      /// <summary>
      /// Traces a multi-tensor function on one placeholder (named "args") per entry of
      /// <paramref name="dtypes"/>/<paramref name="shapes"/>.
      /// </summary>
      /// <param name="func">Function to trace.</param>
      /// <param name="dtypes">Placeholder dtypes, one per input.</param>
      /// <param name="shapes">Placeholder shapes; indexed in parallel with <paramref name="dtypes"/>.</param>
      /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
      public ConcreteFunction(Func<Tensors, Tensors> func,
          TF_DataType[] dtypes, TensorShape[] shapes)
      {
          if (func is null)
              throw new ArgumentNullException(nameof(func));

          string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
          func_graph = new FuncGraph(func_name);
          func_graph.as_default();
          // Exit() in finally so a failure while tracing cannot leave func_graph active.
          try
          {
              var inputs = new Tensors();
              for (int i = 0; i < dtypes.Length; i++)
              {
                  inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args"));
              }
              Outputs = func(inputs);
              OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray();
              func_graph.ToGraph(CollectOperations(), inputs, Outputs, null);
          }
          finally
          {
              func_graph.Exit();
          }
      }

      /// <summary>
      /// Finalizes the wrapped graph from all of its nodes with the given inputs/outputs,
      /// then records the output specs in <see cref="OutputStructure"/>.
      /// </summary>
      public void ToGraph(Tensors inputs, Tensors outputs)
      {
          func_graph.ToGraph(CollectOperations(),
              inputs,
              outputs,
              null);
          OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
      }

      /// <summary>Calls <c>as_default()</c> on the wrapped graph.</summary>
      public void Enter()
      {
          func_graph.as_default();
      }

      /// <summary>Calls <c>Exit()</c> on the wrapped graph.</summary>
      public void Exit()
      {
          func_graph.Exit();
      }

      /// <summary>Calls the function with <paramref name="inputs"/> followed by <see cref="CapturedInputs"/>.</summary>
      public Tensors FilteredCall(Tensors inputs)
      {
          return CallFlat(inputs, CapturedInputs);
      }

      /// <summary>
      /// Executes the wrapped function.
      /// </summary>
      /// <param name="args">Explicit inputs; must not be null.</param>
      /// <param name="captured_inputs">Captured tensors appended after <paramref name="args"/>; null is treated as none.</param>
      /// <returns>
      /// The function outputs. NOTE(review): when not executing eagerly the forward call is
      /// skipped and null is returned (and recorded) — graph-mode calls look unimplemented here.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
      public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs)
      {
          if (args is null)
              throw new ArgumentNullException(nameof(args));

          var executing_eagerly = tf.Context.executing_eagerly();

          // Flat input list: explicit args first, then captured tensors.
          var tensor_inputs = new Tensors();
          tensor_inputs.AddRange(args);
          if (captured_inputs != null)
              tensor_inputs.AddRange(captured_inputs);
          args = tensor_inputs.ToArray();

          var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0;
          // No tape is watching; skip to running the function.
          if (possible_gradient_type == 0 && executing_eagerly)
          {
              var attrs = new object[]
              {
                  "executor_type", "",
                  "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
              };
              return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
          }

          var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
          var (forward_function, args_with_tangents) = forward_backward.Forward();
          Tensors flat_outputs = null;
          if (executing_eagerly)
              flat_outputs = forward_function.Call(args_with_tangents);
          forward_backward.Record(flat_outputs);
          return flat_outputs;
      }

      /// <summary>
      /// Builds the forward/backward pair used when a gradient tape may be recording.
      /// <paramref name="possible_gradient_type"/> and <paramref name="executing_eagerly"/> are
      /// currently ignored; first-order tape functions with <c>tape_watching: true</c> are always used.
      /// </summary>
      ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
      {
          var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
          return new ForwardBackwardCall(functions, args, tape_watching: true);
      }

      /// <summary>
      /// Snapshots every node registered in <c>func_graph</c>, cast with <c>as Operation</c>
      /// (entries that are not an <see cref="Operation"/> become null).
      /// </summary>
      Operation[] CollectOperations()
          => func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

      /// <summary>Returns <see cref="Name"/>, or the default type name when no name is available.</summary>
      public override string ToString()
          => Name ?? base.ToString();
  }
  144. }