You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AutoGraph.cs 2.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. using System;
  2. using System.Diagnostics;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  5. namespace Tensorflow.Graphs
  6. {
  /// <summary>
  /// Converts managed tensor functions into graph functions. Each <c>to_graph</c> overload
  /// traces the delegate once into a new <see cref="FuncGraph"/> registered under a unique
  /// name, and returns a wrapper that executes that registered function by name when running
  /// eagerly, or calls the original delegate directly otherwise.
  /// </summary>
  public class AutoGraph
  {
      /// <summary>
      /// Traces a single-input tensor function into a graph function.
      /// </summary>
      /// <param name="func">
      /// Function to trace. It is invoked once during tracing with a placeholder of
      /// type <paramref name="dtype"/>.
      /// </param>
      /// <param name="dtype">Data type of the placeholder used while tracing.</param>
      /// <returns>
      /// A delegate that, when <c>tf.executing_eagerly()</c> is true, runs the traced function
      /// by name and returns its first output; otherwise it calls <paramref name="func"/>
      /// directly inside a session opened on the input tensor's graph.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
      public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32)
      {
          if (func is null)
          {
              throw new ArgumentNullException(nameof(func));
          }

          string funcName = $"{func.Method.Name}_{ops.uid_function()}";
          _trace_function(funcName, new[] { dtype }, inputs => func(inputs[0]));

          return (Tensor x) =>
          {
              if (tf.executing_eagerly())
              {
                  return _execute_eager(funcName, new[] { x });
              }

              // NOTE(review): the session is not used by func; the scope is kept only so
              // graph-mode behavior matches the previous implementation.
              using (tf.Session(x.graph))
              {
                  return func(x);
              }
          };
      }

      /// <summary>
      /// Traces a two-input tensor function into a graph function.
      /// </summary>
      /// <param name="func">
      /// Function to trace. It is invoked once during tracing with two placeholders.
      /// </param>
      /// <param name="dtypes">
      /// Data types of the first and second placeholders. Missing entries (or a null
      /// array) default to <c>tf.int32</c>; entries beyond the second are ignored.
      /// </param>
      /// <returns>
      /// A delegate that, when <c>tf.executing_eagerly()</c> is true, runs the traced function
      /// by name and returns its first output; otherwise it calls <paramref name="func"/>
      /// directly inside a session opened on the first input's graph.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
      public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes)
      {
          if (func is null)
          {
              throw new ArgumentNullException(nameof(func));
          }

          string funcName = $"{func.Method.Name}_{ops.uid_function()}";
          var inputTypes = new TF_DataType[]
          {
              dtypes != null && dtypes.Length >= 1 ? dtypes[0] : tf.int32,
              dtypes != null && dtypes.Length >= 2 ? dtypes[1] : tf.int32,
          };
          _trace_function(funcName, inputTypes, inputs => func(inputs[0], inputs[1]));

          return (Tensor a, Tensor b) =>
          {
              if (tf.executing_eagerly())
              {
                  return _execute_eager(funcName, new[] { a, b });
              }

              // Both operands are expected to live in the same graph (checked in debug builds only).
              Debug.Assert(a.graph == b.graph);
              // NOTE(review): the session is not used by func; the scope is kept only so
              // graph-mode behavior matches the previous implementation.
              using (tf.Session(a.graph))
              {
                  return func(a, b);
              }
          };
      }

      /// <summary>
      /// Builds a <see cref="FuncGraph"/> named <paramref name="funcName"/>. It creates one
      /// placeholder per entry of <paramref name="inputTypes"/>, calls <paramref name="body"/>
      /// once on them, and passes the graph's operations, the placeholders and the single
      /// output to <c>ToGraph</c>.
      /// </summary>
      /// <param name="funcName">Name to register the graph function under.</param>
      /// <param name="inputTypes">Placeholder data types, in input order.</param>
      /// <param name="body">Receives the placeholders and returns the single output tensor.</param>
      private static void _trace_function(string funcName, TF_DataType[] inputTypes, Func<Tensor[], Tensor> body)
      {
          var graph = new FuncGraph(funcName);
          graph.as_default();
          try
          {
              var inputs = inputTypes.Select(dtype => tf.placeholder(dtype)).ToArray();
              var output = body(inputs);
              // OfType never yields nulls, unlike an unchecked `x as Operation` projection.
              var opers = graph._nodes_by_name.Values.OfType<Operation>().ToArray();
              graph.ToGraph(opers, inputs, new[] { output }, null);
          }
          finally
          {
              // Always pair as_default() with Exit(), even when tracing throws.
              graph.Exit();
          }
      }

      /// <summary>
      /// Executes the registered graph function <paramref name="funcName"/> on the current
      /// context's device, requesting one output, and returns that output.
      /// </summary>
      /// <param name="funcName">Name the function was registered under during tracing.</param>
      /// <param name="inputs">Input tensors, in the same order as the traced placeholders.</param>
      private static Tensor _execute_eager(string funcName, Tensor[] inputs)
      {
          var result = tf.Runner.TFE_Execute(tf.Context,
              tf.Context.DeviceName,
              funcName,
              inputs,
              null,
              1);
          return result[0];
      }
  }
  76. }