You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Execute.cs 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  using System;
  using System.Collections.Generic;
  using System.Linq;
  namespace Tensorflow.Eager
  {
      /// <summary>
      /// Eager-mode helpers: runs a single TensorFlow op through the TFE C API
      /// (<c>c_api.TFE_QuickExecute</c>) and converts op arguments to tensors of a common dtype.
      /// </summary>
      public class Execute
      {
          /// <summary>
          /// Execute a TensorFlow operation eagerly and return its first output.
          /// </summary>
          /// <param name="ctx">
          /// The eager context to run in (the value of context.context());
          /// <c>ensure_initialized</c> is called on it before the op is executed.
          /// </param>
          /// <param name="op_name">
          /// Name of the TensorFlow operation (see REGISTER_OP in C++ code) to execute.
          /// </param>
          /// <param name="num_outputs">
          /// The number of outputs of the operation to fetch. Must not be negative.
          /// </param>
          /// <param name="inputs">
          /// The inputs to the operation. Every entry must be a non-null <see cref="EagerTensor"/>.
          /// </param>
          /// <param name="attrs">
          /// A flat array with alternating string attr names and attr values for this operation.
          /// </param>
          /// <param name="name">
          /// Customized name for the operation. NOTE(review): currently not forwarded to the C API.
          /// </param>
          /// <returns>
          /// An <see cref="EagerTensor"/> wrapping the first output handle, or <c>null</c> when
          /// <paramref name="num_outputs"/> is 0. Only the first output is wrapped and returned.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="num_outputs"/> is negative.</exception>
          /// <exception cref="ArgumentException">An entry of <paramref name="inputs"/> is null or not an <see cref="EagerTensor"/>.</exception>
          public Tensor execute(Context ctx, string op_name, int num_outputs,
              Tensor[] inputs, object[] attrs,
              string name = null)
          {
              if (inputs == null)
                  throw new ArgumentNullException(nameof(inputs));
              if (num_outputs < 0)
                  throw new ArgumentOutOfRangeException(nameof(num_outputs), num_outputs, "num_outputs must not be negative.");

              ctx.ensure_initialized();
              using var status = new Status();

              // Reject null / non-eager inputs with a descriptive error instead of letting an
              // unchecked `as` cast surface as a NullReferenceException.
              var input_handles = inputs
                  .Select((x, i) => x is EagerTensor eager
                      ? eager.GetTfeTensorHandle()
                      : throw new ArgumentException($"Input {i} of op '{op_name}' is not an EagerTensor.", nameof(inputs)))
                  .ToArray();

              var outputs = new IntPtr[num_outputs];
              c_api.TFE_QuickExecute(ctx,
                  ctx.device_name,
                  op_name,
                  input_handles,
                  inputs.Length,
                  op => wrap_tfe_src.SetOpAttrs(op, attrs),
                  outputs,
                  num_outputs,
                  status);
              status.Check(true);

              // With zero requested outputs there is no handle to wrap (outputs is empty).
              if (num_outputs == 0)
                  return null;

              TFE_TensorHandle tfe_tensor_handle = outputs[0];
              return new EagerTensor(tfe_tensor_handle);
          }

          /// <summary>
          /// Converts <paramref name="args"/> into tensors that share a single dtype.
          /// </summary>
          /// <param name="ctx">Eager context passed through to <c>ops.convert_to_tensor</c>.</param>
          /// <param name="default_dtype">
          /// Dtype returned when <paramref name="args"/> is empty, and the preferred dtype used
          /// when converting non-tensor values.
          /// </param>
          /// <param name="args">Values to convert; a null array is treated as empty.</param>
          /// <returns>
          /// The common dtype and the tensors:
          /// for empty <paramref name="args"/>, <paramref name="default_dtype"/> and a null array;
          /// when every arg is already an <see cref="EagerTensor"/>, the dtype of the first one and the args as-is;
          /// otherwise each arg is converted, and the dtype of the first converted value is used for the rest.
          /// </returns>
          /// <exception cref="NotImplementedException">
          /// <paramref name="args"/> mixes <see cref="EagerTensor"/> instances with values that still need conversion.
          /// </exception>
          public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
          {
              // `args` defaults to null; treat that like an empty argument list. Returning here also
              // avoids indexing args[0] when args is empty and no default dtype was supplied.
              if (args == null || args.Length == 0)
                  return (default_dtype, null);

              if (args.All(x => x is EagerTensor))
                  return (((EagerTensor)args[0]).dtype, args.Cast<EagerTensor>().ToArray());

              // Dtype carried by any EagerTensor mixed in with raw values (last one wins).
              var dtype = TF_DataType.DtInvalid;
              foreach (var x in args)
              {
                  if (x is EagerTensor et)
                      dtype = et.dtype;
              }

              if (dtype != TF_DataType.DtInvalid)
                  throw new NotImplementedException("args_to_matching_eager: mixing EagerTensor and non-tensor arguments is not supported yet.");

              var ret = new List<Tensor>(args.Length);
              foreach (var t in args)
              {
                  ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx));
                  // The first conversion fixes the dtype; later values are converted to match it.
                  if (dtype == TF_DataType.DtInvalid)
                      dtype = ret[ret.Count - 1].dtype;
              }
              return (dtype, ret.ToArray());
          }
      }
  }