You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Operation.cs 8.0 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago

  1. using Google.Protobuf.Collections;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. namespace Tensorflow
  8. {
  9. public partial class Operation : ITensorOrOperation
  10. {
  /// <summary>Native operation handle handed to the <c>c_api</c> calls in this class (<c>_c_op</c> in Python).</summary>
  private readonly IntPtr _handle;

  /// <summary>Graph associated with this operation. Get-only; set by the constructors.</summary>
  public Graph graph { get; }

  /// <summary>Publicly writable storage that <see cref="_id"/> reads from.</summary>
  public int _id_value;

  /// <summary>Current value of <see cref="_id_value"/>.</summary>
  public int _id
  {
      get { return _id_value; }
  }

  /// <summary>Alias for <see cref="OpType"/>.</summary>
  public string type
  {
      get { return OpType; }
  }

  /// <summary>Returns this instance itself.</summary>
  public Operation op
  {
      get { return this; }
  }

  /// <summary>Always <c>TF_DataType.DtInvalid</c>.</summary>
  public TF_DataType dtype
  {
      get { return TF_DataType.DtInvalid; }
  }

  /// <summary>Status object reused by the native calls of this class that report through a status argument.</summary>
  private Status status = new Status();

  /// <summary>Operation name, read from the native handle via <c>TF_OperationName</c> on every access.</summary>
  public string name
  {
      get { return c_api.StringPiece(c_api.TF_OperationName(_handle)); }
  }

  /// <summary>Operation type, read from the native handle via <c>TF_OperationOpType</c> on every access.</summary>
  public string OpType
  {
      get { return c_api.StringPiece(c_api.TF_OperationOpType(_handle)); }
  }

  /// <summary>Device string, read from the native handle via <c>TF_OperationDevice</c> on every access.</summary>
  public string Device
  {
      get { return c_api.StringPiece(c_api.TF_OperationDevice(_handle)); }
  }

  /// <summary>Cache behind <see cref="node_def"/>; null until the property is first read.</summary>
  private NodeDef _node_def;
  /// <summary>
  /// <c>NodeDef</c> of this operation. It is obtained from <c>GetNodeDef</c> the first time
  /// the property is read and the same instance is returned afterwards.
  /// </summary>
  public NodeDef node_def
  {
      get { return _node_def ?? (_node_def = GetNodeDef()); }
  }
  /// <summary>
  /// Wraps an already existing native operation handle.
  /// </summary>
  /// <param name="handle">
  /// Native operation pointer. When it equals <see cref="IntPtr.Zero"/> the constructor
  /// leaves every member at its default value.
  /// </param>
  public Operation(IntPtr handle)
  {
      if (handle != IntPtr.Zero)
      {
          _handle = handle;
          graph = ops.get_default_graph();

          // One Tensor wrapper per output slot reported for the handle.
          _outputs = new Tensor[NumOutputs];
          foreach (var slot in Enumerable.Range(0, NumOutputs))
              _outputs[slot] = new Tensor(this, slot, OutputType(slot));
      }
  }
  /// <summary>
  /// Creates a new native operation of type <paramref name="opType"/> in graph <paramref name="g"/>.
  /// </summary>
  /// <param name="g">Graph the operation is created in.</param>
  /// <param name="opType">Operation type passed to <c>TF_NewOperation</c>.</param>
  /// <param name="oper_name">Name given to the new operation.</param>
  /// <remarks>
  /// The <c>dtype</c> attribute is always set to <c>TF_INT32</c>.
  /// The status of <c>TF_FinishOperation</c> is checked, so a failed build is reported
  /// immediately rather than surfacing later through a zero handle.
  /// </remarks>
  public Operation(Graph g, string opType, string oper_name)
  {
      graph = g;

      var desc = c_api.TF_NewOperation(g, opType, oper_name);
      c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);

      // Keep the finished operation: previously the return value was dropped, which left
      // _handle at IntPtr.Zero and made name/OpType/Device and every other native call
      // on this instance operate on a null operation.
      _handle = c_api.TF_FinishOperation(desc, status);
      status.Check(true);
  }
  /// <summary>
  /// Builds a new operation from a <c>NodeDef</c> and adds it to graph <paramref name="g"/>.
  /// </summary>
  /// <param name="node_def">Definition of the node to create; its <c>Op</c> and <c>Attr</c> members are read here.</param>
  /// <param name="g">Parent graph. The finished operation is handed to its <c>_add_op</c>.</param>
  /// <param name="inputs">Flat list of input tensors; regrouped per input argument of the op definition.</param>
  /// <param name="output_types">
  /// Overwritten with the types reported by <c>OutputType</c>; the value supplied by the caller is not read.
  /// </param>
  /// <param name="control_inputs">
  /// Control dependencies. Only <c>Operation</c> instances are handled; any other element
  /// raises <see cref="NotImplementedException"/>.
  /// </param>
  /// <param name="input_types">Not consulted by this constructor.</param>
  /// <param name="original_op">Not consulted by this constructor.</param>
  /// <param name="op_def">Op definition; fetched with <c>g.GetOpDef(node_def.Op)</c> when null.</param>
  /// <exception cref="NotImplementedException">A control input is not an <c>Operation</c>.</exception>
  public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
  {
      graph = g;

      // Gather the control dependencies.
      var control_ops = new List<Operation>();
      foreach (var dependency in control_inputs ?? new ITensorOrOperation[0])
      {
          if (dependency is Operation dependency_op)
              control_ops.Add(dependency_op);
          else
              throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {dependency}");
      }

      // Use the graph's op definition when the caller did not supply one.
      op_def = op_def ?? g.GetOpDef(node_def.Op);

      var inputs_by_argument = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
      _handle = ops._create_c_op(g, node_def, inputs_by_argument, control_ops.ToArray());

      // Record the output types, then create one Tensor wrapper per output.
      output_types = new TF_DataType[NumOutputs];
      for (var slot = 0; slot < output_types.Length; ++slot)
          output_types[slot] = OutputType(slot);

      _outputs = new Tensor[NumOutputs];
      for (var slot = 0; slot < _outputs.Length; ++slot)
          _outputs[slot] = new Tensor(this, slot, OutputType(slot));

      graph._add_op(this);

      // Control-flow post-processing only runs for a non-zero native handle.
      if (_handle == IntPtr.Zero)
          return;

      _control_flow_post_processing();
  }
  /// <summary>
  /// Runs this operation by forwarding to <c>ops._run_using_default_session</c> together with
  /// this operation's <see cref="graph"/>.
  /// </summary>
  /// <param name="feed_dict">Feed items passed through unchanged; may be null.</param>
  /// <param name="session">Session passed through unchanged; may be null.</param>
  public void run(FeedItem[] feed_dict = null, Session session = null)
      => ops._run_using_default_session(this, feed_dict, graph, session);
  /// <summary>
  /// Regroups the flat <paramref name="inputs"/> list so that there is exactly one entry per
  /// input argument declared in <paramref name="op_def"/>.
  /// </summary>
  /// <param name="op_def">Op definition whose <c>InputArg</c> entries drive the grouping.</param>
  /// <param name="inputs">Flat list of input tensors, consumed from front to back.</param>
  /// <param name="attrs">Attribute map used to look up how many tensors a sequence argument takes.</param>
  /// <returns>
  /// One element per input argument: a <c>Tensor[]</c> for arguments that name a number
  /// attribute or a type-list attribute, a single <c>Tensor</c> otherwise.
  /// </returns>
  private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
  {
      var grouped = new List<object>();
      var offset = 0;

      foreach (var arg in op_def.InputArg)
      {
          int consumed;

          if (!string.IsNullOrEmpty(arg.NumberAttr))
          {
              // Sequence whose length is stored in an integer attribute.
              consumed = (int)attrs[arg.NumberAttr].I;
              grouped.Add(inputs.Skip(offset).Take(consumed).ToArray());
          }
          else if (!string.IsNullOrEmpty(arg.TypeListAttr))
          {
              // Sequence whose length equals the number of entries in a type-list attribute.
              consumed = attrs[arg.TypeListAttr].List.Type.Count;
              grouped.Add(inputs.Skip(offset).Take(consumed).ToArray());
          }
          else
          {
              // Plain argument: exactly one tensor, stored without wrapping it in an array.
              consumed = 1;
              grouped.Add(inputs[offset]);
          }

          offset += consumed;
      }

      return grouped.ToArray();
  }
  /// <summary>
  /// Reads the attribute <paramref name="name"/> of this operation and returns one field of
  /// the parsed <c>AttrValue</c>.
  /// </summary>
  /// <typeparam name="T">
  /// Selects the returned field for attributes other than "T", "dtype" and "shape":
  /// a type named <c>Boolean</c> yields <c>B</c>, a type named <c>String</c> yields <c>S</c>.
  /// </typeparam>
  /// <param name="name">Attribute name. "T" and "dtype" return <c>Type</c>; "shape" returns <c>Shape</c>.</param>
  /// <returns>The selected <c>AttrValue</c> field, boxed as <see cref="object"/>.</returns>
  /// <exception cref="NotImplementedException">
  /// The attribute is not one of the special names and <typeparamref name="T"/> is not supported.
  /// </exception>
  public object get_attr<T>(string name)
  {
      AttrValue attr;
      using (var proto = new Buffer())
      {
          c_api.TF_OperationGetAttrValueProto(_handle, name, proto, status);
          status.Check(true);
          attr = AttrValue.Parser.ParseFrom(proto);
      }

      // Well-known attribute names decide the field on their own.
      if (name == "T" || name == "dtype")
          return attr.Type;
      if (name == "shape")
          return attr.Shape;

      // Everything else is selected by the simple name of the requested type.
      var requested = typeof(T).Name;
      if (requested == "Boolean")
          return attr.B;
      if (requested == "String")
          return attr.S;

      throw new NotImplementedException($"Unsupported field type in {attr.ToString()}");
  }
  /// <summary>
  /// Returns the result of <c>TF_OperationGetAttrMetadata</c> for attribute
  /// <paramref name="attr_name"/> of this operation.
  /// </summary>
  /// <param name="attr_name">Name of the attribute to describe.</param>
  /// <param name="s">Status object handed to the native call; it is not checked here.</param>
  /// <returns>The metadata value produced by the native call.</returns>
  public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
      => c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
  /// <summary>
  /// Asks the native library for this operation's <c>NodeDef</c> and parses the result.
  /// </summary>
  /// <returns>The <c>NodeDef</c> parsed from the buffer filled by <c>TF_OperationToNodeDef</c>.</returns>
  private NodeDef GetNodeDef()
  {
      using (var s = new Status())
      {
          using (var serialized = new Buffer())
          {
              c_api.TF_OperationToNodeDef(_handle, serialized, s);
              s.Check();

              // Parse before the buffer is disposed at the end of the using block.
              var parsed = NodeDef.Parser.ParseFrom(serialized);
              return parsed;
          }
      }
  }
  /// <summary>
  /// Describes the operation by name and type, or reports it as undefined when the native
  /// handle is <see cref="IntPtr.Zero"/>.
  /// </summary>
  public override string ToString()
  {
      if (_handle == IntPtr.Zero)
          return "tf.Operation Undefined";

      return $"tf.Operation '{name}' type={OpType}";
  }
  /// <summary>Wraps a native handle in a new <see cref="Operation"/>.</summary>
  /// <param name="handle">Native operation pointer passed to the <c>Operation(IntPtr)</c> constructor.</param>
  public static implicit operator Operation(IntPtr handle)
  {
      return new Operation(handle);
  }

  /// <summary>Exposes the native handle stored in <paramref name="op"/>.</summary>
  /// <param name="op">Operation whose handle is returned; it is dereferenced without a null check.</param>
  public static implicit operator IntPtr(Operation op)
  {
      return op._handle;
  }
  /// <summary>
  /// Compares by native handle: an <see cref="IntPtr"/> is compared with this operation's
  /// handle directly, another <see cref="Operation"/> through its own handle. Any other
  /// argument falls back to <c>base.Equals</c>.
  /// </summary>
  /// <param name="obj">Value to compare with.</param>
  /// <returns>True when the handles match, otherwise the result of the base comparison.</returns>
  // NOTE(review): no GetHashCode override is visible in this part of the partial class — confirm one exists elsewhere.
  public override bool Equals(object obj)
  {
      if (obj is IntPtr pointer)
          return pointer == _handle;

      if (obj is Operation other)
          return other._handle == _handle;

      return base.Equals(obj);
  }
  197. }
  198. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。