You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Operation.cs 6.4 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. namespace Tensorflow
  6. {
  /// <summary>
  /// Managed wrapper around a native TensorFlow operation handle. Metadata, inputs,
  /// outputs and control edges are all read through the <c>c_api.TF_Operation*</c> functions.
  /// </summary>
  public class Operation
  {
      // Native operation handle. Stays IntPtr.Zero when constructed from a null handle.
      private readonly IntPtr _handle;

      /// <summary>Graph this operation was created in (set only by the Graph-taking constructors).</summary>
      public Graph Graph { get; }

      /// <summary>Id obtained from <c>Graph._next_id()</c>; 0 unless built via the NodeDef constructor.</summary>
      public int _id => _id_value;
      private int _id_value;

      // Status object shared by the list-length queries and TF_FinishOperation; it is not checked here.
      private Status status = new Status();

      /// <summary>Operation name, via TF_OperationName.</summary>
      public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle));

      /// <summary>Operation type string, via TF_OperationOpType.</summary>
      public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));

      /// <summary>Assigned device string, via TF_OperationDevice.</summary>
      public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

      /// <summary>Number of outputs, via TF_OperationNumOutputs.</summary>
      public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);

      /// <summary>Data type of output <paramref name="index"/>.</summary>
      public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));

      /// <summary>Length of the named output list (status is not checked).</summary>
      public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);

      /// <summary>The producer output feeding input <paramref name="index"/>.</summary>
      public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));

      /// <summary>Data type of input <paramref name="index"/>.</summary>
      public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));

      /// <summary>Length of the named input list (status is not checked).</summary>
      public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);

      /// <summary>Number of inputs, via TF_OperationNumInputs.</summary>
      public int NumInputs => c_api.TF_OperationNumInputs(_handle);

      /// <summary>Number of consumers of output <paramref name="index"/>.</summary>
      public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));

      /// <summary>
      /// Returns up to <paramref name="max_consumers"/> inputs that consume output <paramref name="index"/>.
      /// </summary>
      /// <param name="index">Output slot of this operation.</param>
      /// <param name="max_consumers">Capacity of the native buffer; non-positive values yield an empty array.</param>
      /// <returns>The consumer inputs actually written by the native call.</returns>
      public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
      {
          if (max_consumers <= 0)
              return new TF_Input[0];

          // The buffer must hold max_consumers entries because that is the capacity reported to native code.
          int size = Marshal.SizeOf<TF_Input>();
          IntPtr buffer = Marshal.AllocHGlobal(size * max_consumers);
          try
          {
              var entries = (TF_Input*)buffer;
              int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), entries, max_consumers);
              // The returned count may exceed the buffer capacity; only the first max_consumers entries are filled.
              int written = Math.Min(num, max_consumers);
              var consumers = new TF_Input[written];
              for (int i = 0; i < written; i++)
                  consumers[i] = entries[i];
              return consumers;
          }
          finally
          {
              Marshal.FreeHGlobal(buffer);
          }
      }

      /// <summary>Number of control inputs, via TF_OperationNumControlInputs.</summary>
      public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

      /// <summary>Wraps each control-input handle in a new <see cref="Operation"/>.</summary>
      public unsafe Operation[] GetControlInputs()
      {
          int count = NumControlInputs;
          if (count <= 0)
              return new Operation[0];

          IntPtr buffer = Marshal.AllocHGlobal(IntPtr.Size * count);
          try
          {
              c_api.TF_OperationGetControlInputs(_handle, buffer, count);
              return ReadOperationHandles(buffer, count);
          }
          finally
          {
              Marshal.FreeHGlobal(buffer);
          }
      }

      /// <summary>Number of control outputs, via TF_OperationNumControlOutputs.</summary>
      public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

      /// <summary>Wraps each control-output handle in a new <see cref="Operation"/>.</summary>
      public unsafe Operation[] GetControlOutputs()
      {
          // Capacity and loop bound must both come from the control-OUTPUT count.
          int count = NumControlOutputs;
          if (count <= 0)
              return new Operation[0];

          IntPtr buffer = Marshal.AllocHGlobal(IntPtr.Size * count);
          try
          {
              c_api.TF_OperationGetControlOutputs(_handle, buffer, count);
              return ReadOperationHandles(buffer, count);
          }
          finally
          {
              Marshal.FreeHGlobal(buffer);
          }
      }

      // Reads `count` consecutive native pointers from `buffer` and wraps each one as an Operation.
      private static Operation[] ReadOperationHandles(IntPtr buffer, int count)
      {
          var result = new Operation[count];
          for (int i = 0; i < count; i++)
              result[i] = new Operation(Marshal.ReadIntPtr(buffer, i * IntPtr.Size));
          return result;
      }

      private Tensor[] _outputs;

      /// <summary>Output tensors; populated only by the NodeDef constructor (null otherwise).</summary>
      public Tensor[] outputs => _outputs;

      /// <summary>Input tensors; not assigned within this class.</summary>
      public Tensor[] inputs;

      /// <summary>Wraps an existing native handle. A null handle leaves the wrapper at IntPtr.Zero.</summary>
      public Operation(IntPtr handle)
      {
          if (handle == IntPtr.Zero)
              return;
          _handle = handle;
      }

      /// <summary>
      /// Creates a new operation in <paramref name="g"/> with its "dtype" attribute set to TF_INT32.
      /// </summary>
      public Operation(Graph g, string opType, string oper_name)
      {
          Graph = g;
          var desc = c_api.TF_NewOperation(g, opType, oper_name);
          // NOTE(review): "dtype" is hard-coded to TF_INT32 and the status is not checked.
          c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
          // Capture the finished operation so this wrapper's accessors refer to it.
          _handle = c_api.TF_FinishOperation(desc, status);
      }

      /// <summary>
      /// Creates the native op described by <paramref name="node_def"/> in <paramref name="g"/>,
      /// builds one <see cref="Tensor"/> per output and registers the op with the graph.
      /// </summary>
      /// <remarks>
      /// The <paramref name="output_types"/> argument is overwritten with the native output types.
      /// <paramref name="control_inputs"/>, <paramref name="input_types"/> and <paramref name="original_op"/>
      /// are not used.
      /// </remarks>
      public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
      {
          Graph = g;
          _id_value = Graph._next_id();
          // The lookup result is not used afterwards; it is kept for the side effects of GetOpDef (e.g. failing on unknown ops).
          if (op_def == null)
              op_def = g.GetOpDef(node_def.Op);
          _handle = ops._create_c_op(g, node_def, inputs);

          // Query the native output count once rather than on every loop iteration.
          int num_outputs = NumOutputs;
          _outputs = new Tensor[num_outputs];
          output_types = new TF_DataType[num_outputs];
          for (int i = 0; i < num_outputs; i++)
          {
              output_types[i] = OutputType(i);
              _outputs[i] = new Tensor(this, i, output_types[i]);
          }
          Graph._add_op(this);
      }

      /// <summary>
      /// Minimal attribute lookup: "dtype" returns the first output tensor, "shape" returns an
      /// empty <see cref="TensorShapeProto"/>, and any other name returns null.
      /// </summary>
      public object get_attr(string name)
      {
          object ret = null;
          switch (name)
          {
              case "dtype":
                  // NOTE(review): returns the Tensor itself, not its data type, and throws when
                  // _outputs is null or empty. Confirm what callers expect.
                  ret = _outputs[0];
                  break;
              case "shape":
                  ret = new TensorShapeProto();
                  break;
          }
          return ret;
      }

      /// <summary>Reads attribute metadata via TF_OperationGetAttrMetadata; errors are reported in <paramref name="s"/>.</summary>
      public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
      {
          return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
      }

      /// <summary>Serialises this operation to a <see cref="NodeDef"/>; throws via <c>Status.Check</c> on failure.</summary>
      public NodeDef GetNodeDef()
      {
          using (var s = new Status())
          using (var buffer = new Buffer())
          {
              c_api.TF_OperationToNodeDef(_handle, buffer, s);
              s.Check();
              return NodeDef.Parser.ParseFrom(buffer);
          }
      }

      public static implicit operator Operation(IntPtr handle) => new Operation(handle);
      public static implicit operator IntPtr(Operation op) => op._handle;

      /// <summary>Equal to another Operation or a raw IntPtr that refers to the same native handle.</summary>
      public override bool Equals(object obj)
      {
          switch (obj)
          {
              case IntPtr ptr:
                  return ptr == _handle;
              case Operation other:
                  return other._handle == _handle;
          }
          return base.Equals(obj);
      }

      /// <summary>Hash of the native handle, consistent with <see cref="Equals(object)"/>.</summary>
      public override int GetHashCode() => _handle.GetHashCode();
  }
  144. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。