You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Operation.cs 12 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. using Google.Protobuf.Collections;
  2. //using Newtonsoft.Json;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  namespace Tensorflow
  {
      /// <summary>
      /// Represents a graph node that performs computation on tensors.
      ///
      /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
      /// more `Tensor` objects as input, and produces zero or more `Tensor`
      /// objects as output. Objects of type `Operation` are created by
      /// calling an op constructor (such as `tf.matmul`)
      /// or `tf.Graph.create_op`.
      ///
      /// For example `c = tf.matmul(a, b)` creates an `Operation` of type
      /// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
      /// as output.
      ///
      /// After the graph has been launched in a session, an `Operation` can
      /// be executed by passing it to `tf.Session.run`.
      /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
      /// </summary>
      public partial class Operation : ITensorOrOperation
      {
          private readonly IntPtr _handle; // _c_op in python
          private readonly IntPtr _operDesc;
          private Graph _graph;

          //[JsonIgnore]
          public Graph graph => _graph;

          //[JsonIgnore]
          public int _id => _id_value;

          //[JsonIgnore]
          public int _id_value;

          /// <summary>Alias of <see cref="OpType"/> (python-style name).</summary>
          public string type => OpType;

          //[JsonIgnore]
          /// <summary>Returns this instance; lets callers treat tensors and operations uniformly.</summary>
          public Operation op => this;

          /// <summary>Operations have no dtype of their own; always <c>TF_DataType.DtInvalid</c>.</summary>
          public TF_DataType dtype => TF_DataType.DtInvalid;

          // NOTE(review): one native status per Operation that is never disposed here;
          // confirm whether Status releases its handle via a finalizer.
          private Status status = new Status();

          /// <summary>Operation name, read from the native handle on every access.</summary>
          public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));

          /// <summary>Op type (e.g. "MatMul"), read from the native handle on every access.</summary>
          public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));

          /// <summary>Requested device string, read from the native handle on every access.</summary>
          public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

          private NodeDef _node_def;

          /// <summary>
          /// The <c>NodeDef</c> of this operation, fetched from the native graph on
          /// first access and cached afterwards.
          /// </summary>
          public NodeDef node_def
          {
              get
              {
                  if (_node_def == null)
                      _node_def = GetNodeDef();
                  return _node_def;
              }
          }

          /// <summary>
          /// Wraps an existing native operation handle.
          /// </summary>
          /// <param name="handle">Native operation handle. If <c>IntPtr.Zero</c>, the instance is left uninitialized.</param>
          /// <param name="g">Owning graph; defaults to <c>ops.get_default_graph()</c>.</param>
          public Operation(IntPtr handle, Graph g = null)
          {
              if (handle == IntPtr.Zero)
                  return;

              _handle = handle;
              _graph = g ?? ops.get_default_graph();
              _outputs = new Tensor[NumOutputs];
              for (int i = 0; i < NumOutputs; i++)
                  _outputs[i] = new Tensor(this, i, OutputType(i));

              // Dict mapping op name to file and line information for op colocation
              // context managers.
              _control_flow_context = graph._get_control_flow_context();

              // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
          }

          /// <summary>
          /// Creates and finishes a new native operation of type <paramref name="opType"/>
          /// named <paramref name="oper_name"/> in <paramref name="g"/>, with its
          /// <c>dtype</c> attribute set to <c>TF_INT32</c>.
          /// </summary>
          /// <remarks>
          /// NOTE(review): the status of <c>TF_FinishOperation</c> is not checked and
          /// <c>_outputs</c> is not populated by this constructor — confirm this is intended.
          /// </remarks>
          public Operation(Graph g, string opType, string oper_name)
          {
              _graph = g;

              _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
              c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
              _handle = c_api.TF_FinishOperation(_operDesc, status);

              // Dict mapping op name to file and line information for op colocation
              // context managers.
              _control_flow_context = graph._get_control_flow_context();
          }

          /// <summary>
          /// Creates an `Operation`.
          /// </summary>
          /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
          /// <param name="g">`Graph`. The parent graph.</param>
          /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`. `null` is treated as no inputs.</param>
          /// <param name="output_types">
          /// list of `DType` objects. Currently ignored: the output types are always
          /// read back from the created native operation.
          /// </param>
          /// <param name="control_inputs">
          /// list of operations or tensors from which to have a
          /// control dependency.
          /// </param>
          /// <param name="input_types">
          /// List of `DType` objects representing the
          /// types of the tensors accepted by the `Operation`. By default
          /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
          /// reference-typed inputs must specify these explicitly.
          /// </param>
          /// <param name="original_op"></param>
          /// <param name="op_def">Op definition; looked up from the graph by <c>node_def.Op</c> when null.</param>
          /// <exception cref="NotImplementedException">A control input is neither an Operation nor a Tensor.</exception>
          public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
          {
              _graph = g;

              // Mirrors python's `if inputs is None: inputs = []`; without this a null
              // `inputs` crashes in _reconstruct_sequence_inputs for any op with inputs.
              if (inputs == null)
                  inputs = Array.Empty<Tensor>();

              // Build the list of control inputs.
              var control_input_ops = new List<Operation>();
              if (control_inputs != null)
              {
                  foreach (var c in control_inputs)
                  {
                      switch (c)
                      {
                          case Operation c1:
                              control_input_ops.Add(c1);
                              break;
                          case Tensor tensor:
                              control_input_ops.Add(tensor.op);
                              break;
                          // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                          //case IndexedSlices islices:
                          //    control_input_ops.Add(islices.op);
                          //    break;
                          default:
                              throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                      }
                  }
              }

              // Dict mapping op name to file and line information for op colocation
              // context managers.
              _control_flow_context = graph._get_control_flow_context();

              // This will be set by self.inputs.
              if (op_def == null)
                  op_def = g.GetOpDef(node_def.Op);

              var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
              (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

              // Initialize self._outputs, querying each output's type from the native op once.
              int num_outputs = NumOutputs;
              output_types = new TF_DataType[num_outputs];
              _outputs = new Tensor[num_outputs];
              for (int i = 0; i < num_outputs; i++)
              {
                  output_types[i] = OutputType(i);
                  _outputs[i] = new Tensor(this, i, output_types[i]);
              }

              graph._add_op(this);

              if (_handle != IntPtr.Zero)
                  _control_flow_post_processing();
          }

          /// <summary>
          /// Runs this operation via <c>ops._run_using_default_session</c>.
          /// </summary>
          /// <param name="feed_dict">Optional feeds for the run.</param>
          /// <param name="session">Session to use; when null the default session is presumably used (see <c>ops._run_using_default_session</c>).</param>
          public void run(FeedItem[] feed_dict = null, Session session = null)
          {
              ops._run_using_default_session(this, feed_dict, graph, session);
          }

          /// <summary>
          /// Groups the flat <paramref name="inputs"/> array according to
          /// <paramref name="op_def"/>'s input arguments: an argument with a
          /// <c>number_attr</c> or <c>type_list_attr</c> consumes that many inputs as a
          /// <c>Tensor[]</c>; any other argument consumes exactly one <c>Tensor</c>.
          /// </summary>
          /// <returns>One entry per input argument, in op_def order.</returns>
          private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
          {
              var grouped_inputs = new List<object>();
              int i = 0;

              foreach (var input_arg in op_def.InputArg)
              {
                  int input_len;
                  bool is_sequence;

                  if (!string.IsNullOrEmpty(input_arg.NumberAttr))
                  {
                      input_len = (int)attrs[input_arg.NumberAttr].I;
                      is_sequence = true;
                  }
                  else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                  {
                      input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
                      is_sequence = true;
                  }
                  else
                  {
                      input_len = 1;
                      is_sequence = false;
                  }

                  if (is_sequence)
                      grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray());
                  else
                      grouped_inputs.Add(inputs[i]);

                  i += input_len;
              }

              return grouped_inputs.ToArray();
          }

          /// <summary>
          /// Returns the value of the attribute <paramref name="name"/> of this operation.
          /// </summary>
          /// <returns>
          /// <c>null</c> if the attribute has no value set; the <c>DataType</c> for
          /// type attributes; a UTF-8 decoded string for bytes (<c>s</c>) attributes;
          /// otherwise the raw protobuf field value (int, float, bool, shape, tensor, ...).
          /// </returns>
          /// <exception cref="NotImplementedException">The attribute is a list.</exception>
          public object get_attr(string name)
          {
              AttrValue x = null;

              using (var buf = new Buffer())
              {
                  c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
                  status.Check(true);
                  x = AttrValue.Parser.ParseFrom(buf);
              }

              // Switch on the oneof enum itself: its ToString() yields PascalCase names
              // ("None", "List", "Type"), so comparing against lowercase strings never matched.
              switch (x.ValueCase)
              {
                  case AttrValue.ValueOneofCase.None:
                      return null;
                  case AttrValue.ValueOneofCase.List:
                      throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
                  case AttrValue.ValueOneofCase.Type:
                      return x.Type;
              }

              // Remaining oneof case names match the generated property names (S, I, F, B, Shape, ...).
              object result = x.GetType().GetProperty(x.ValueCase.ToString()).GetValue(x);
              if (result is Google.Protobuf.ByteString byteString)
                  return byteString.ToStringUtf8();
              return result;
          }

          /// <summary>
          /// Returns the native metadata for attribute <paramref name="attr_name"/>; errors are reported through <paramref name="s"/>.
          /// </summary>
          public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
          {
              return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
          }

          /// <summary>
          /// Serializes the native operation to a <c>NodeDef</c> and parses it.
          /// </summary>
          private NodeDef GetNodeDef()
          {
              using (var s = new Status())
              using (var buffer = new Buffer())
              {
                  c_api.TF_OperationToNodeDef(_handle, buffer, s);
                  s.Check();
                  return NodeDef.Parser.ParseFrom(buffer);
              }
          }

          public override string ToString()
          {
              return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
          }

          public static implicit operator Operation(IntPtr handle) => new Operation(handle);
          public static implicit operator IntPtr(Operation op) => op._handle;

          /// <summary>
          /// Two operations are equal when they wrap the same native handle; an
          /// <c>IntPtr</c> compares equal to the handle it matches.
          /// </summary>
          /// <remarks>
          /// NOTE(review): Equals is overridden here without GetHashCode; confirm another
          /// part of this partial class overrides GetHashCode consistently (e.g. on _handle).
          /// </remarks>
          public override bool Equals(object obj)
          {
              switch (obj)
              {
                  case IntPtr val:
                      return val == _handle;
                  case Operation val:
                      return val._handle == _handle;
              }

              return base.Equals(obj);
          }

          /// <summary>
          /// Update the input to this operation at the given index.
          ///
          /// NOTE: This is for TF internal use only. Please don't use it.
          /// </summary>
          /// <param name="index">the index of the input to update.</param>
          /// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
          public void _update_input(int index, Tensor tensor)
          {
              _assert_same_graph(tensor);

              var input = _tf_input(index);
              var output = tensor._as_tf_output();

              // Reset cached inputs.
              _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None
              // TODO: implement below code dependencies
              //c_api.UpdateEdge(_graph._c_graph, output, input);
          }

          private void _assert_same_graph(Tensor tensor)
          {
              //TODO: implement
          }

          /// <summary>
          /// Create and return a new TF_Output for output_idx'th output of this op.
          /// </summary>
          public TF_Output _tf_output(int output_idx)
          {
              var tf_output = new TF_Output(op, output_idx);
              return tf_output;
          }

          /// <summary>
          /// Create and return a new TF_Input for input_idx'th input of this op.
          /// </summary>
          public TF_Input _tf_input(int input_idx)
          {
              var tf_input = new TF_Input(op, input_idx);
              return tf_input;
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。