You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Operation.cs 17 kB

6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using Tensorflow.NumPy;
  14. using System;
  15. using System.Collections.Generic;
  16. using System.Linq;
  17. using Tensorflow.Util;
  18. using static Tensorflow.Binding;
  19. using Google.Protobuf;
  20. using Google.Protobuf.WellKnownTypes;
  21. using System.Diagnostics;
  22. namespace Tensorflow
  23. {
  /// <summary>
  /// Represents a graph node that performs computation on tensors.
  ///
  /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
  /// more `Tensor` objects as input, and produces zero or more `Tensor`
  /// objects as output. Objects of type `Operation` are created by
  /// calling an op constructor (such as `tf.matmul`) or `tf.Graph.create_op`.
  ///
  /// For example, `c = tf.matmul(a, b)` creates an `Operation` of type
  /// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
  /// as output.
  ///
  /// After the graph has been launched in a session, an `Operation` can
  /// be executed by passing it to `tf.Session.run`.
  /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
  /// </summary>
  public partial class Operation : ITensorOrOperation
  {
      /// <summary>Native <c>TF_Operation*</c> (<c>_c_op</c> in Python); <see cref="IntPtr.Zero"/> when unset.</summary>
      protected IntPtr _handle;
      protected Graph _graph;
      internal Func<Operation, object[], Tensor[]> _gradient_function;

      /// <summary>Alias of <see cref="OpType"/>.</summary>
      public string type => OpType;
      public Graph graph => _graph;
      public int _id => _id_value;
      public int _id_value { get; set; }
      public Operation op => this;
      public TF_DataType dtype => output.dtype;

      // The three native accessors below return "" when there is no native handle.
      public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle));
      public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle));
      public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle));

      /// <summary>The <see cref="NodeDef"/> serialized by the C API; re-fetched on every access.</summary>
      public NodeDef node_def => GetNodeDef();

      protected Operation() { }

      /// <summary>
      /// Wraps an existing native operation handle.
      /// </summary>
      /// <param name="handle">
      /// Native <c>TF_Operation*</c>. When <see cref="IntPtr.Zero"/> the instance is left
      /// uninitialized (no graph, no outputs).
      /// </param>
      /// <param name="g">Owning graph; defaults to <c>ops.get_default_graph()</c>.</param>
      /// <remarks>
      /// Unlike the <see cref="NodeDef"/>-based constructor, this one does not assign an id,
      /// does not call <c>_add_op</c>, and does not run <c>_control_flow_post_processing()</c>;
      /// the caller is responsible for calling the latter when using this constructor.
      /// </remarks>
      public Operation(IntPtr handle, Graph g = null)
      {
          if (handle == IntPtr.Zero)
              return;

          _handle = handle;
          _graph = g ?? ops.get_default_graph();

          // NumOutputs is queried once; it does not change while the outputs are built.
          int num_outputs = NumOutputs;
          _outputs = new Tensor[num_outputs];
          for (int i = 0; i < num_outputs; i++)
              _outputs[i] = new Tensor(this, i, OutputType(i));

          // Capture the control-flow context that is active on the graph right now.
          _control_flow_context = _graph._get_control_flow_context();
      }

      /// <summary>
      /// Creates an `Operation` in <paramref name="g"/> from a <see cref="NodeDef"/>.
      /// </summary>
      /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
      /// <param name="g">`Graph`. The parent graph.</param>
      /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
      /// <param name="output_types">
      /// Ignored: the output dtypes are always read back from the created native op.
      /// Kept for signature compatibility.
      /// </param>
      /// <param name="control_inputs">
      /// list of operations or tensors from which to have a control dependency.
      /// </param>
      /// <param name="input_types">
      /// List of `DType` objects representing the types of the tensors accepted by the
      /// `Operation`. Currently unused by this constructor.
      /// </param>
      /// <param name="original_op">Currently unused by this constructor.</param>
      /// <param name="op_def">
      /// Op definition; looked up via <c>g.GetOpDef(node_def.Op)</c> when null.
      /// </param>
      /// <exception cref="NotImplementedException">
      /// A control input is neither an <see cref="Operation"/> nor a <see cref="Tensor"/>.
      /// </exception>
      public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
      {
          _graph = g;

          // Normalize control inputs to the operations that must run first.
          var control_input_ops = new List<Operation>();
          if (control_inputs != null)
          {
              foreach (var c in control_inputs)
              {
                  switch (c)
                  {
                      case Operation control_op:
                          control_input_ops.Add(control_op);
                          break;
                      case Tensor tensor:
                          control_input_ops.Add(tensor.op);
                          break;
                      // TODO: accept IndexedSlices (via `islices.op`) once that type exists.
                      default:
                          throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                  }
              }
          }

          _id_value = _graph._next_id();

          // Capture the control-flow context that is active on the graph right now.
          _control_flow_context = _graph._get_control_flow_context();

          op_def ??= g.GetOpDef(node_def.Op);
          (_handle, _) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def);

          // Output dtypes come from the native op, not from the `output_types` argument.
          int num_outputs = NumOutputs;
          _outputs = new Tensor[num_outputs];
          for (int i = 0; i < num_outputs; i++)
              _outputs[i] = new Tensor(this, i, OutputType(i));

          _graph._add_op(this);

          if (_handle != IntPtr.Zero)
              _control_flow_post_processing();
      }

      /// <summary>
      /// Runs this operation via <c>ops._run_using_default_session</c>.
      /// </summary>
      /// <param name="feed_dict">Optional feeds.</param>
      /// <param name="session">Session to use; when null the default session is used.</param>
      public void run(FeedItem[] feed_dict = null, Session session = null)
      {
          ops._run_using_default_session(this, feed_dict, graph, session);
      }

      /// <summary>
      /// Reads attribute <paramref name="name"/> and converts it to <typeparamref name="T"/>.
      /// </summary>
      /// <remarks>
      /// Value types go through <see cref="Convert.ChangeType(object, Type)"/> (so e.g. a boxed
      /// <c>long</c> can be read as <c>int</c>); reference types are cast directly.
      /// </remarks>
      public virtual T get_attr<T>(string name)
      {
          object value = get_attr(name);
          return typeof(T).IsValueType
              ? (T)Convert.ChangeType(value, typeof(T))
              : (T)value;
      }

      /// <summary>Reads a type-valued attribute directly through the C API.</summary>
      internal unsafe TF_DataType _get_attr_type(string name)
      {
          Status status = new();
          TF_DataType result;
          c_api.TF_OperationGetAttrType(_handle, name, new IntPtr(&result), status);
          status.Check(true);
          return result;
      }

      /// <summary>Reads an int-valued attribute directly through the C API.</summary>
      /// <remarks>
      /// <c>TF_OperationGetAttrInt</c> writes an <c>int64_t</c>, so the native call must target an
      /// 8-byte local; passing the address of an <c>int</c> would overrun it. The value is then
      /// narrowed (unchecked) to keep the existing <c>int</c> return type.
      /// </remarks>
      internal unsafe int _get_attr_int(string name)
      {
          Status status = new();
          long result;
          c_api.TF_OperationGetAttrInt(_handle, name, new IntPtr(&result), status);
          status.Check(true);
          return unchecked((int)result);
      }

      /// <summary>Reads a bool-valued attribute directly through the C API.</summary>
      /// <remarks>
      /// <c>TF_OperationGetAttrBool</c> writes an <c>unsigned char</c>; it is read as a byte and
      /// any non-zero value is treated as <c>true</c>.
      /// </remarks>
      internal unsafe bool _get_attr_bool(string name)
      {
          Status status = new();
          byte result;
          c_api.TF_OperationGetAttrBool(_handle, name, new IntPtr(&result), status);
          status.Check(true);
          return result != 0;
      }

      /// <summary>
      /// Reads list attribute <paramref name="name"/> as an array of <typeparamref name="T"/>.
      /// </summary>
      /// <returns>
      /// An empty array when the attribute has no value set; the integer list converted to
      /// <typeparamref name="T"/> when <typeparamref name="T"/> is <c>int</c> or <c>long</c>;
      /// <c>null</c> for any other <typeparamref name="T"/>.
      /// In eager mode the result of <see cref="get_attr(string)"/> is cast to <c>T[]</c>.
      /// </returns>
      /// <exception cref="InvalidOperationException">The attribute holds a non-list value.</exception>
      public virtual T[] get_attr_list<T>(string name)
      {
          if (tf.executing_eagerly())
              return (T[])get_attr(name);

          AttrValue attr;
          var buf = new Buffer();
          try
          {
              Status status = new();
              c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
              status.Check(true);
              attr = AttrValue.Parser.ParseFrom(buf.ToArray());
          }
          finally
          {
              // Parsing copies the bytes, so the native buffer can be freed right away.
              buf.Release();
          }

          // An unset oneof reports `None` (it never stringifies to ""), and `attr.List` is null
          // unless the oneof holds a list, so both cases must be checked before reading it.
          if (attr.ValueCase == AttrValue.ValueOneofCase.None)
              return Array.Empty<T>();
          if (attr.ValueCase != AttrValue.ValueOneofCase.List)
              throw new InvalidOperationException($"Attribute '{name}' is not a list attribute (value case: {attr.ValueCase}).");

          if (typeof(T) == typeof(int) || typeof(T) == typeof(long))
              return attr.List.I.Select(v => (T)Convert.ChangeType(v, typeof(T))).ToArray();

          return null;
      }

      /// <summary>
      /// Reads attribute <paramref name="name"/> from the native op.
      /// </summary>
      /// <returns>
      /// <list type="bullet">
      /// <item>unset value: an empty <c>object[]</c>;</item>
      /// <item>list value: the first non-empty repeated field as an array (strings are UTF-8
      /// decoded, types are mapped to <see cref="TF_DataType"/>), or <c>null</c> if every field is empty;</item>
      /// <item>type value: the corresponding <see cref="TF_DataType"/>;</item>
      /// <item>anything else: the result of <c>ProtoUtils.GetSingleAttrValue</c>.</item>
      /// </list>
      /// </returns>
      public virtual object get_attr(string name)
      {
          AttrValue attr;
          var buf = new Buffer();
          try
          {
              Status status = new();
              c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
              status.Check(true);
              var tf_buffer = c_api.TF_GetBuffer(buf);
              // The span aliases native memory; ParseFrom copies it before the buffer is released.
              attr = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>());
          }
          finally
          {
              buf.Release();
          }

          var oneof_value = attr.ValueCase;
          if (oneof_value == AttrValue.ValueOneofCase.None)
              return new object[0];

          if (oneof_value == AttrValue.ValueOneofCase.List)
          {
              // Repeated protobuf fields are never null; only their emptiness matters.
              var list = attr.List;
              if (list.S.Count > 0)
                  return list.S.Select(s => s.ToStringUtf8()).ToArray();
              if (list.I.Count > 0)
                  return list.I.ToArray();
              if (list.F.Count > 0)
                  return list.F.ToArray();
              if (list.B.Count > 0)
                  return list.B.ToArray();
              if (list.Shape.Count > 0)
                  return list.Shape.ToArray();
              if (list.Tensor.Count > 0)
                  return list.Tensor.ToArray();
              if (list.Func.Count > 0)
                  return list.Func.ToArray();
              if (list.Type.Count > 0)
                  return list.Type.Select(t => t.as_tf_dtype()).ToArray();
              return null;
          }

          if (oneof_value == AttrValue.ValueOneofCase.Type)
              return dtypes.as_tf_dtype(attr.Type);

          return ProtoUtils.GetSingleAttrValue(attr, oneof_value);
      }

      /// <summary>Returns the C API metadata for attribute <paramref name="attr_name"/>.</summary>
      public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
      {
          return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
      }

      /// <summary>Currently a no-op; the native device is not changed.</summary>
      [Obsolete("The implementation is not complete.")]
      internal void _set_device_from_string(string device_str)
      {
          // TODO(Rinne): complete it with new C API `SetRequestedDevice`.
      }

      /// <summary>Forwards to <see cref="_set_device_from_string(string)"/> (currently a no-op).</summary>
      [Obsolete("The implementation is not complete.")]
      internal void _set_device(string device)
      {
          _set_device_from_string(device);
      }

      /// <summary>Serializes the native op into a fresh <see cref="NodeDef"/>.</summary>
      private NodeDef GetNodeDef()
      {
          var buffer = new Buffer();
          try
          {
              Status status = new();
              c_api.TF_OperationToNodeDef(_handle, buffer, status);
              status.Check(throwException: true);
              return NodeDef.Parser.ParseFrom(buffer.ToArray());
          }
          finally
          {
              buffer.Release();
          }
      }

      /// <summary>
      /// Update the input to this operation at the given index.
      ///
      /// NOTE: This is for TF internal use only. Please don't use it.
      /// </summary>
      /// <remarks>
      /// The native edge rewrite (<c>TF_UpdateEdge</c>) is currently disabled, so this only
      /// clears the cached inputs; the underlying graph edge is not modified.
      /// </remarks>
      /// <param name="index">the index of the input to update.</param>
      /// <param name="tensor">the Tensor to be used as the input at the given index.</param>
      public void _update_input(int index, Tensor tensor)
      {
          _assert_same_graph(tensor);

          // Reset cached inputs so they are reloaded from the C API on next access.
          _inputs_val = null;

          // TODO: re-enable once TF_UpdateEdge is wired up, e.g.
          //   c_api.TF_UpdateEdge(_graph, tensor._as_tf_output(), _tf_input(index), status);
      }

      private void _assert_same_graph(Tensor tensor)
      {
          //TODO: implement
      }

      /// <summary>
      /// Create and return a new TF_Output for output_idx'th output of this op.
      /// </summary>
      public TF_Output _tf_output(int output_idx)
      {
          return new TF_Output(_handle, output_idx);
      }

      /// <summary>
      /// Create and return a new TF_Input for input_idx'th input of this op.
      /// </summary>
      public TF_Input _tf_input(int input_idx)
      {
          return new TF_Input(_handle, input_idx);
      }

      public NDArray numpy() => throw new NotImplementedException("");

      /// <summary>
      /// Appends one new output tensor per (dtype, shape) pair after the existing outputs.
      /// </summary>
      /// <exception cref="ArgumentNullException">Either array is null.</exception>
      /// <exception cref="ArgumentException">The arrays differ in length.</exception>
      internal void _add_outputs(TF_DataType[] types, Shape[] shapes)
      {
          if (types is null)
              throw new ArgumentNullException(nameof(types));
          if (shapes is null)
              throw new ArgumentNullException(nameof(shapes));
          // Validated in all builds: a Debug-only assertion let release builds fail later
          // with an opaque IndexOutOfRangeException.
          if (types.Length != shapes.Length)
              throw new ArgumentException($"Expected one shape per dtype, but got {types.Length} dtypes and {shapes.Length} shapes.", nameof(shapes));

          int orig_num_outputs = this.outputs.Length;
          // `_outputs` is an array, so adding outputs means building a new one.
          var new_outputs = new List<Tensor>(_outputs);
          for (int i = 0; i < types.Length; i++)
          {
              var t = new Tensor(this, orig_num_outputs + i, types[i]);
              t.shape = shapes[i];
              new_outputs.Add(t);
          }
          _outputs = new_outputs.ToArray();
      }

      /// <summary>Sets <paramref name="attr_name"/> to a function reference named <paramref name="func_name"/>.</summary>
      internal void _set_func_attr(string attr_name, string func_name)
      {
          var func = new NameAttrList() { Name = func_name };
          _set_attr(attr_name, new AttrValue() { Func = func });
      }

      /// <summary>Sets a list-of-types attribute; does nothing when <paramref name="types"/> is null or empty.</summary>
      internal void _set_type_list_attr(string attr_name, DataType[] types)
      {
          if (types is null || types.Length == 0)
              return;

          var type_list = new AttrValue.Types.ListValue();
          type_list.Type.AddRange(types);
          _set_attr(attr_name, new AttrValue() { List = type_list });
      }

      /// <summary>Serializes <paramref name="attr_value"/> and sets it on the native op.</summary>
      internal void _set_attr(string attr_name, AttrValue attr_value)
      {
          var buffer = new Buffer(attr_value.ToByteArray());
          try
          {
              _set_attr_with_buf(attr_name, buffer);
          }
          finally
          {
              buffer.Release();
          }
      }

      /// <summary>Sets an already-serialized attribute value via <c>TFC_SetAttr</c>.</summary>
      internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
      {
          Status status = new();
          c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status);
          status.Check(true);
      }
  }
  384. }