You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FuncGraph.cs 1.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Runtime.InteropServices;
  5. using System.Text;
  6. using Tensorflow.Functions;
  7. using static Tensorflow.Binding;
  namespace Tensorflow.Graphs
  {
      /// <summary>
      /// Graph representing a function body.
      /// </summary>
      /// <remarks>
      /// Operations are built into this graph and then converted into a native function by
      /// <see cref="ToGraph"/>. That call also copies the function into the graph that was the
      /// default graph when this instance was constructed, and registers it with the eager context.
      /// </remarks>
      public class FuncGraph : Graph
      {
          /// <summary>Graph that was the default graph when this instance was constructed.</summary>
          private readonly Graph outer_graph;

          /// <summary>Name requested for the native function.</summary>
          private readonly string func_name;

          /// <summary>
          /// Native function handle produced by <see cref="ToGraph"/>;
          /// <see cref="IntPtr.Zero"/> until a conversion has succeeded.
          /// </summary>
          private IntPtr func_handle;

          /// <summary>
          /// Name of the function. Before <see cref="ToGraph"/> has produced a native function,
          /// this is the name passed to the constructor.
          /// </summary>
          // Querying the native name through a null handle would dereference NULL in native code,
          // so fall back to the requested name. ToGraph does not append a hash, so the name is the same.
          public string FuncName => func_handle == IntPtr.Zero
              ? func_name
              : c_api.StringPiece(c_api.TF_FunctionName(func_handle));

          /// <summary>
          /// Construct a new FuncGraph.
          /// </summary>
          /// <param name="name">Name to give the native function created by <see cref="ToGraph"/>.</param>
          public FuncGraph(string name) : base()
          {
              // Remember whichever graph is the default right now; ToGraph copies the finished
              // function back into it.
              outer_graph = ops.get_default_graph();
              func_name = name;
          }

          /// <summary>
          /// Converts the given operations of this graph into a native function, copies that
          /// function into the outer graph and registers it with the eager context.
          /// </summary>
          /// <param name="opers">Operations that make up the function body.</param>
          /// <param name="inputs">Operations whose output 0 becomes a function input.</param>
          /// <param name="outputs">Operations whose output 0 becomes a function output.</param>
          /// <param name="output_names">
          /// Optional output names. <c>null</c> or empty passes no names to the native API;
          /// otherwise the array must have exactly one entry per element of <paramref name="outputs"/>.
          /// </param>
          /// <returns>The native function handle, which is also kept for <see cref="FuncName"/>.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="opers"/>, <paramref name="inputs"/> or <paramref name="outputs"/> is null.
          /// </exception>
          /// <exception cref="ArgumentException">
          /// <paramref name="output_names"/> is non-empty and its length differs from <paramref name="outputs"/>.
          /// </exception>
          public IntPtr ToGraph(Operation[] opers,
              Operation[] inputs, Operation[] outputs,
              string[] output_names)
          {
              if (opers is null)
              {
                  throw new ArgumentNullException(nameof(opers));
              }
              if (inputs is null)
              {
                  throw new ArgumentNullException(nameof(inputs));
              }
              if (outputs is null)
              {
                  throw new ArgumentNullException(nameof(outputs));
              }

              // The native API reads `noutputs` names whenever output_names is non-null, so a
              // mismatched array would be read out of bounds in native code.
              bool hasOutputNames = output_names != null && output_names.Length > 0;
              if (hasOutputNames && output_names.Length != outputs.Length)
              {
                  throw new ArgumentException(
                      $"Expected {outputs.Length} output names but got {output_names.Length}.",
                      nameof(output_names));
              }

              using var status = new Status();

              var handle = c_api.TF_GraphToFunction(_handle,
                  func_name,
                  false,            // append_hash_to_fn_name: keep func_name exactly as given
                  opers.Length,
                  opers.Select(x => (IntPtr)x).ToArray(),
                  inputs.Length,
                  inputs.Select(x => new TF_Output(x, 0)).ToArray(),
                  outputs.Length,
                  outputs.Select(x => new TF_Output(x, 0)).ToArray(),
                  hasOutputNames ? output_names : null,
                  IntPtr.Zero,      // no function options
                  null,             // no description
                  status.Handle);
              status.Check(true);

              // Only record the handle once the conversion is known to have succeeded.
              func_handle = handle;

              c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);
              status.Check(true);

              c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle);
              status.Check(true);

              return func_handle;
          }
      }
  }