You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Operation.cs 4.3 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. namespace Tensorflow
  6. {
  /// <summary>
  /// Managed wrapper around a native TensorFlow graph operation handle.
  /// Most members forward directly to the <c>c_api</c> P/Invoke layer using that handle.
  /// </summary>
  public class Operation
  {
      private readonly IntPtr _handle;

      /// <summary>Graph this operation belongs to. Not set by the <see cref="Operation(IntPtr)"/> constructor.</summary>
      public Graph Graph { get; }

      /// <summary>Id obtained from <c>Graph._next_id()</c>; only assigned by the NodeDef constructor (0 otherwise).</summary>
      public int _id => _id_value;
      private int _id_value;

      // Status object handed to the native calls below that report errors (list lengths, TF_FinishOperation).
      private readonly Status status = new Status();

      /// <summary>Name of the native operation.</summary>
      public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle));

      /// <summary>Op type (e.g. the kernel name) of the native operation.</summary>
      public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));

      /// <summary>Device string assigned to the native operation.</summary>
      public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

      /// <summary>Number of outputs of the native operation.</summary>
      public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);

      /// <summary>Data type of output <paramref name="index"/>.</summary>
      public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));

      /// <summary>Length of the output list argument called <paramref name="name"/>.</summary>
      public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);

      /// <summary>The output that feeds input <paramref name="index"/> of this operation.</summary>
      public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));

      /// <summary>Data type of input <paramref name="index"/>.</summary>
      public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));

      /// <summary>Length of the input list argument called <paramref name="name"/>.</summary>
      public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);

      /// <summary>Number of inputs of the native operation.</summary>
      public int NumInputs => c_api.TF_OperationNumInputs(_handle);

      /// <summary>Number of consumers of output <paramref name="index"/>.</summary>
      public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));

      /// <summary>
      /// Returns the consumers of output <paramref name="index"/>.
      /// </summary>
      /// <param name="index">Output index of this operation.</param>
      /// <param name="max_consumers">
      /// Capacity of the native result buffer; typically <see cref="OutputNumConsumers(int)"/>.
      /// </param>
      /// <returns>
      /// At most <paramref name="max_consumers"/> consumers. Returns an empty array when
      /// <paramref name="max_consumers"/> is not positive.
      /// </returns>
      public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
      {
          if (max_consumers <= 0)
          {
              return new TF_Input[0];
          }

          int size = Marshal.SizeOf<TF_Input>();
          // The native call writes up to max_consumers TF_Input structs, so the buffer must hold that many.
          IntPtr buffer = Marshal.AllocHGlobal(size * max_consumers);
          try
          {
              var handle = (TF_Input*)buffer;
              int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);

              // Per the TensorFlow C API, the return value is the total consumer count, which may exceed
              // max_consumers; only the first max_consumers entries were actually written.
              int count = Math.Min(num, max_consumers);
              var consumers = new TF_Input[count];
              for (int i = 0; i < count; i++)
              {
                  consumers[i] = handle[i];
              }
              return consumers;
          }
          finally
          {
              Marshal.FreeHGlobal(buffer);
          }
      }

      /// <summary>Number of control inputs of the native operation.</summary>
      public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

      /// <summary>Number of control outputs of the native operation.</summary>
      public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

      private Tensor[] _outputs;

      /// <summary>Output tensors; only populated by the NodeDef constructor (null otherwise).</summary>
      public Tensor[] outputs => _outputs;

      /// <summary>Input tensors. Not assigned anywhere in this class.</summary>
      public Tensor[] inputs;

      /// <summary>
      /// Wraps an existing native operation handle. A zero handle leaves the wrapper with a zero handle.
      /// </summary>
      public Operation(IntPtr handle)
      {
          if (handle == IntPtr.Zero)
          {
              return;
          }
          _handle = handle;
      }

      /// <summary>
      /// Creates and finishes a native operation of type <paramref name="opType"/> named
      /// <paramref name="oper_name"/> in <paramref name="g"/>, with its "dtype" attribute set to TF_INT32.
      /// </summary>
      public Operation(Graph g, string opType, string oper_name)
      {
          Graph = g;
          var desc = c_api.TF_NewOperation(g, opType, oper_name);
          c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
          // Keep the finished op's handle; otherwise every native accessor above would query IntPtr.Zero.
          _handle = c_api.TF_FinishOperation(desc, status);
      }

      /// <summary>
      /// Builds the native operation described by <paramref name="node_def"/>, creates one output
      /// <see cref="Tensor"/> per native output, and registers the op with <paramref name="g"/>.
      /// </summary>
      /// <param name="node_def">Node definition to create.</param>
      /// <param name="g">Owning graph.</param>
      /// <param name="inputs">Input tensors passed to <c>ops._create_c_op</c>.</param>
      /// <param name="output_types">
      /// Output dtypes. When null or shorter than the op's output count, the dtype reported
      /// by the native op is used for the missing entries.
      /// </param>
      /// <param name="control_inputs">Currently unused.</param>
      /// <param name="input_types">Currently unused.</param>
      /// <param name="original_op">Currently unused.</param>
      /// <param name="op_def">Currently unused.</param>
      public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
      {
          Graph = g;
          _id_value = Graph._next_id();
          _handle = ops._create_c_op(g, node_def, inputs);

          // NumOutputs is a native call; query it once.
          int num_outputs = NumOutputs;
          _outputs = new Tensor[num_outputs];
          for (int i = 0; i < num_outputs; i++)
          {
              // output_types defaults to null; fall back to the dtype recorded on the native op.
              var dtype = output_types != null && i < output_types.Length
                  ? output_types[i]
                  : OutputType(i);
              _outputs[i] = new Tensor(this, i, dtype);
          }

          Graph._add_op(this);
      }

      /// <summary>
      /// Looks up a small set of attributes by name.
      /// </summary>
      /// <returns>
      /// For "dtype", the first output tensor; for "shape", a new empty <see cref="TensorShapeProto"/>;
      /// null for any other name.
      /// </returns>
      public object get_attr(string name)
      {
          switch (name)
          {
              case "dtype":
                  // NOTE(review): this returns the first output Tensor itself, not a TF_DataType;
                  // callers presumably unwrap it — confirm before changing.
                  return _outputs[0];
              case "shape":
                  return new TensorShapeProto();
              default:
                  return null;
          }
      }

      /// <summary>Wraps a raw handle; every conversion creates a new <see cref="Operation"/> instance.</summary>
      public static implicit operator Operation(IntPtr handle)
      {
          return new Operation(handle);
      }

      /// <summary>Exposes the native handle of <paramref name="op"/>.</summary>
      public static implicit operator IntPtr(Operation op)
      {
          return op._handle;
      }

      /// <summary>
      /// An <see cref="IntPtr"/> is equal when it holds the same native handle; any other object
      /// falls back to reference equality.
      /// </summary>
      public override bool Equals(object obj)
      {
          switch (obj)
          {
              case IntPtr val:
                  return val == _handle;
          }
          return base.Equals(obj);
      }

      /// <summary>
      /// Hash of the native handle. This is consistent with <see cref="Equals(object)"/>: a handle-equal
      /// IntPtr hashes the same, and a reference-equal Operation has the same handle.
      /// </summary>
      public override int GetHashCode() => _handle.GetHashCode();
  }
  100. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。

Contributors (1)