You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Operation.cs 7.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. using Google.Protobuf.Collections;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. namespace Tensorflow
  8. {
  9. public partial class Operation : IReturnTensorOrOperation
  10. {
  // Native operation pointer (_c_op in python).
  private readonly IntPtr _handle;

  /// <summary>
  /// Graph associated with this operation. Set by the constructors; it is left
  /// unset when the instance is built from a zero handle.
  /// </summary>
  public Graph Graph { get; }

  /// <summary>Identifier of this operation, read from <c>_id_value</c>.</summary>
  public int _id
  {
      get { return _id_value; }
  }

  // Backing field for _id.
  private int _id_value;

  /// <summary>Alias for <see cref="OpType"/>.</summary>
  public string type
  {
      get { return OpType; }
  }

  // Status object handed to the native calls issued by this instance.
  private Status status = new Status();

  /// <summary>Operation name as reported by <c>TF_OperationName</c>.</summary>
  public string Name
  {
      get { return c_api.StringPiece(c_api.TF_OperationName(_handle)); }
  }

  /// <summary>Operation type as reported by <c>TF_OperationOpType</c>.</summary>
  public string OpType
  {
      get { return c_api.StringPiece(c_api.TF_OperationOpType(_handle)); }
  }

  /// <summary>Device string as reported by <c>TF_OperationDevice</c>.</summary>
  public string Device
  {
      get { return c_api.StringPiece(c_api.TF_OperationDevice(_handle)); }
  }
  // Cached NodeDef; stays null until node_def is read for the first time.
  private NodeDef _node_def;

  /// <summary>
  /// NodeDef of this operation. It is obtained through <c>GetNodeDef()</c> on the
  /// first access and the same instance is returned on later accesses.
  /// </summary>
  public NodeDef node_def
  {
      get
      {
          if (_node_def != null)
              return _node_def;

          _node_def = GetNodeDef();
          return _node_def;
      }
  }
  /// <summary>
  /// Wraps an existing native operation handle and attaches it to the default graph.
  /// </summary>
  /// <param name="handle">
  /// Native operation pointer. When it is <c>IntPtr.Zero</c> the constructor does
  /// nothing, leaving the graph and the outputs unset.
  /// </param>
  public Operation(IntPtr handle)
  {
      if (handle != IntPtr.Zero)
      {
          _handle = handle;
          Graph = ops.get_default_graph();

          // One Tensor per output slot of the native operation.
          _outputs = new Tensor[NumOutputs];
          int slot = 0;
          while (slot < NumOutputs)
          {
              _outputs[slot] = new Tensor(this, slot, OutputType(slot));
              slot++;
          }
      }
  }
  /// <summary>
  /// Creates a new operation of type <paramref name="opType"/> in graph
  /// <paramref name="g"/>, with its "dtype" attribute set to TF_INT32.
  /// </summary>
  /// <param name="g">Graph the operation is added to.</param>
  /// <param name="opType">Type name passed to <c>TF_NewOperation</c>.</param>
  /// <param name="oper_name">Name passed to <c>TF_NewOperation</c>.</param>
  public Operation(Graph g, string opType, string oper_name)
  {
      Graph = g;

      var desc = c_api.TF_NewOperation(g, opType, oper_name);
      c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);

      // Keep the handle returned by the native call; without it every member that
      // reads _handle (Name, OpType, ToString, ...) sees a zero pointer.
      _handle = c_api.TF_FinishOperation(desc, status);

      // Surface a failed build immediately instead of silently ignoring the status.
      status.Check(true);
  }
  /// <summary>
  /// Creates an `Operation`.
  /// </summary>
  /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
  /// <param name="g">`Graph`. The parent graph.</param>
  /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
  /// <param name="output_types">
  /// list of `DType` objects. Not used by this constructor: the output types are
  /// read back from the created native operation.
  /// </param>
  /// <param name="control_inputs">
  /// list of operations from which to have a control dependency.
  /// A null element is rejected with <see cref="NotImplementedException"/>.
  /// </param>
  /// <param name="input_types">
  /// List of `DType` objects representing the
  /// types of the tensors accepted by the `Operation`.
  /// Not used by this constructor.
  /// </param>
  /// <param name="original_op">Not used by this constructor.</param>
  /// <param name="op_def">
  /// Op definition for <c>node_def.Op</c>. When null it is looked up from
  /// <paramref name="g"/> through <c>GetOpDef</c>.
  /// </param>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="node_def"/> or <paramref name="g"/> is null.
  /// </exception>
  public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
  {
      // Fail before any graph state (such as the next id) is consumed.
      if (node_def is null)
          throw new ArgumentNullException(nameof(node_def));
      if (g is null)
          throw new ArgumentNullException(nameof(g));

      Graph = g;

      // Build the list of control inputs.
      var control_input_ops = new List<Operation>();
      if (control_inputs != null)
      {
          foreach (var c in control_inputs)
          {
              switch (c)
              {
                  case Operation c1:
                      control_input_ops.Add(c1);
                      break;
                  default:
                      throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
              }
          }
      }

      _id_value = Graph._next_id();

      if (op_def == null)
          op_def = g.GetOpDef(node_def.Op);

      var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
      _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

      // Initialize the outputs: one Tensor per output slot, typed from the native
      // operation. Each output type is queried exactly once.
      int output_count = NumOutputs;
      _outputs = new Tensor[output_count];
      for (int i = 0; i < output_count; i++)
          _outputs[i] = new Tensor(this, i, OutputType(i));

      Graph._add_op(this);

      if (_handle != IntPtr.Zero)
          _control_flow_post_processing();
  }
  /// <summary>
  /// Regroups the flat <paramref name="inputs"/> array so that there is one entry
  /// per input argument declared in <paramref name="op_def"/>.
  /// </summary>
  /// <param name="op_def">Op definition whose <c>InputArg</c> list drives the grouping.</param>
  /// <param name="inputs">Flat list of input tensors, in argument order.</param>
  /// <param name="attrs">
  /// Attribute map used to resolve the length of sequence arguments (the integer
  /// value of a number attribute, or the element count of a type-list attribute).
  /// </param>
  /// <returns>
  /// An array holding, per input argument, either a single <c>Tensor</c> or a
  /// <c>Tensor[]</c> slice for sequence arguments.
  /// </returns>
  private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
  {
      var groups = new List<object>();
      int offset = 0;

      foreach (var arg in op_def.InputArg)
      {
          int count;
          bool sequence;

          if (!string.IsNullOrEmpty(arg.NumberAttr))
          {
              // Homogeneous list: length is stored in an integer attribute.
              count = (int)attrs[arg.NumberAttr].I;
              sequence = true;
          }
          else if (!string.IsNullOrEmpty(arg.TypeListAttr))
          {
              // Heterogeneous list: length equals the number of listed types.
              count = attrs[arg.TypeListAttr].List.Type.Count;
              sequence = true;
          }
          else
          {
              count = 1;
              sequence = false;
          }

          object group;
          if (sequence)
              group = inputs.Skip(offset).Take(count).ToArray();
          else
              group = inputs[offset];

          groups.Add(group);
          offset += count;
      }

      return groups.ToArray();
  }
  /// <summary>
  /// Reads the attribute <paramref name="name"/> of this operation and returns the
  /// field of the parsed <c>AttrValue</c> that corresponds to it.
  /// </summary>
  /// <typeparam name="T">
  /// Requested value type for attributes other than "T", "dtype" and "shape".
  /// Supported: <c>bool</c> (field B), <c>string</c> (field S),
  /// <c>long</c> (field I) and <c>float</c> (field F).
  /// </typeparam>
  /// <param name="name">Attribute name.</param>
  /// <returns>
  /// The attribute's type for "T"/"dtype", its shape for "shape", otherwise the
  /// field selected by <typeparamref name="T"/>.
  /// </returns>
  /// <exception cref="NotImplementedException">
  /// <typeparamref name="T"/> is not one of the supported types.
  /// </exception>
  public object get_attr<T>(string name)
  {
      AttrValue x = null;

      using (var buf = new Buffer())
      {
          c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
          status.Check(true);
          x = AttrValue.Parser.ParseFrom(buf);
      }

      // Well-known attribute names are resolved regardless of T.
      switch (name)
      {
          case "T":
          case "dtype":
              return x.Type;
          case "shape":
              return x.Shape;
      }

      // Compare the actual types rather than their short names, so an unrelated
      // type that happens to be called "Boolean" or "String" is not matched.
      var requested = typeof(T);
      if (requested == typeof(bool))
          return x.B;
      if (requested == typeof(string))
          return x.S;
      if (requested == typeof(long))
          return x.I;
      if (requested == typeof(float))
          return x.F;

      throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
  }
  /// <summary>
  /// Returns the metadata of attribute <paramref name="attr_name"/> by forwarding
  /// to <c>TF_OperationGetAttrMetadata</c>.
  /// </summary>
  /// <param name="attr_name">Attribute name.</param>
  /// <param name="s">Status object passed through to the native call.</param>
  public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
      => c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
  /// <summary>
  /// Serialises this operation through <c>TF_OperationToNodeDef</c> and parses the
  /// resulting buffer into a <c>NodeDef</c>.
  /// </summary>
  private NodeDef GetNodeDef()
  {
      using (var callStatus = new Status())
      {
          using (var serialized = new Buffer())
          {
              c_api.TF_OperationToNodeDef(_handle, serialized, callStatus);
              callStatus.Check();

              var parsed = NodeDef.Parser.ParseFrom(serialized);
              return parsed;
          }
      }
  }
  /// <summary>
  /// Returns "Undefined" for a zero handle, otherwise the quoted name followed by
  /// the operation type.
  /// </summary>
  public override string ToString()
  {
      if (_handle == IntPtr.Zero)
          return "Undefined";

      return $"'{Name}' type={OpType}";
  }
  /// <summary>Wraps a native operation pointer in a new <see cref="Operation"/>.</summary>
  public static implicit operator Operation(IntPtr handle)
  {
      return new Operation(handle);
  }
  /// <summary>Exposes the native pointer held by <paramref name="op"/>.</summary>
  public static implicit operator IntPtr(Operation op)
  {
      return op._handle;
  }
  /// <summary>
  /// Compares by native handle: an <c>IntPtr</c> is equal when it matches this
  /// operation's handle, and another <see cref="Operation"/> is equal when both
  /// handles match. Anything else falls back to the base implementation.
  /// </summary>
  /// <param name="obj">Object to compare against.</param>
  // NOTE(review): no GetHashCode override is visible in this part of the partial
  // class; confirm one exists elsewhere so it stays consistent with Equals.
  public override bool Equals(object obj)
  {
      if (obj is IntPtr pointer)
          return pointer == _handle;

      if (obj is Operation other)
          return other._handle == _handle;

      return base.Equals(obj);
  }
  192. }
  193. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。