You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AutoGraphAttribute.cs 2.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  // NOTE(review): everything in this block is disabled - it is wrapped in a block comment and
// never compiled. It depends on MethodBoundaryAspect.Fody (see the first using directive);
// presumably that package is no longer referenced, so confirm before uncommenting.
/*using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
    /// <summary>
    /// Method-boundary aspect that traces the decorated method into a <see cref="FuncGraph"/> on
    /// its first call and, on later calls, runs the cached traced function instead of the body.
    /// </summary>
    /// <remarks>
    /// First call: eager execution is disabled, each <see cref="EagerTensor"/> argument is swapped
    /// for a placeholder with the same dtype and shape, and the body runs against those
    /// placeholders. <see cref="OnExit"/> then passes the recorded operations to
    /// <c>FuncGraph.ToGraph</c>, disposes the graph, re-enables eager execution, and caches a
    /// delegate that executes the function by name through <c>tf.Runner.TFE_Execute</c>. That
    /// delegate is immediately invoked on the original tensors to produce the return value.
    /// </remarks>
    [AllowChangingInputArguments]
    public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
    {
        FuncGraph graph;
        Tensor[] originalInputs;
        string func_name;

        // True only while a graph is being traced: set by OnEntry when it builds a new graph,
        // cleared by EndTracing. Lets OnExit/OnException skip the cache-hit path, where no graph exists.
        bool tracing;

        // Traced functions keyed by func_name, shared by every use of this attribute.
        static readonly Dictionary<string, Func<Tensor[], Tensor>> functions = new Dictionary<string, Func<Tensor[], Tensor>>();

        /// <summary>
        /// Runs the cached function if one exists for this instance/method; otherwise starts
        /// tracing by building a new graph and replacing eager tensor arguments with placeholders.
        /// </summary>
        public override void OnEntry(MethodExecutionArgs args)
        {
            // NOTE(review): the key formats args.Instance via ToString(), so instances with equal
            // ToString() output share one cached function - confirm that is intended.
            func_name = $"autograph_{args.Instance}.{args.Method.Name}";
            if (functions.TryGetValue(func_name, out var cached))
            {
                args.ReturnValue = cached(TensorArguments(args));
                args.FlowBehavior = FlowBehavior.Return;
                return;
            }

            tf.compat.v1.disable_eager_execution();
            // Trace the method body into a new graph so it can be turned into a named function.
            graph = new FuncGraph(func_name);
            graph.as_default();
            tracing = true;

            // originalInputs keeps one entry per Tensor argument, in argument order, so it lines
            // up with the graph inputs built in OnExit and with the tensors the cache-hit path passes.
            var inputs = new List<Tensor>();
            for (var i = 0; i < args.Arguments.Length; i++)
            {
                switch (args.Arguments[i])
                {
                    case EagerTensor eager:
                        inputs.Add(eager);
                        args.Arguments[i] = tf.placeholder(eager.dtype, shape: eager.TensorShape);
                        break;
                    case Tensor other:
                        inputs.Add(other);
                        break;
                }
            }
            originalInputs = inputs.ToArray();
        }

        /// <summary>
        /// Finishes tracing: converts the graph, restores eager execution, caches the function,
        /// and replaces the return value with the result of running it on the original tensors.
        /// </summary>
        public override void OnExit(MethodExecutionArgs args)
        {
            // Cache hit: OnEntry already set ReturnValue and built no graph.
            if (!tracing)
            {
                return;
            }

            try
            {
                var output = (Tensor)args.ReturnValue;
                var inputs = TensorArguments(args);
                var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
                graph.ToGraph(opers,
                    inputs.Select(x => x.op).ToArray(),
                    new Operation[] { output.op },
                    null);
            }
            finally
            {
                // Runs even if ToGraph throws, so eager execution is never left disabled.
                EndTracing();
            }

            // Capture the name in a local: the cached lambda outlives this call and must not read
            // the func_name field, which is overwritten if this aspect instance is reused.
            var name = func_name;
            Func<Tensor[], Tensor> function = x =>
            {
                var result = tf.Runner.TFE_Execute(tf.Context,
                    tf.Context.DeviceName,
                    name,
                    x,
                    null,
                    1);
                return result[0];
            };

            functions[name] = function;
            args.ReturnValue = function(originalInputs);
        }

        /// <summary>
        /// The traced body threw: undo the graph and eager-mode changes made by OnEntry.
        /// </summary>
        public override void OnException(MethodExecutionArgs args)
        {
            if (tracing)
            {
                EndTracing();
            }
        }

        // Disposes the traced graph and turns eager execution back on.
        void EndTracing()
        {
            tracing = false;
            graph.Dispose();
            tf.enable_eager_execution();
        }

        // The method's Tensor arguments in order; non-Tensor arguments are skipped instead of
        // being passed on as null.
        static Tensor[] TensorArguments(MethodExecutionArgs args)
            => args.Arguments.OfType<Tensor>().ToArray();
    }
}*/