You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.py.cs 16 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. using System.Threading;
  6. using Tensorflow;
  7. using node_def_pb2 = Tensorflow;
  8. using Google.Protobuf;
  9. using System.Linq;
  10. using NumSharp.Core;
  11. using System.ComponentModel;
  namespace Tensorflow
  {
      public partial class ops
      {
          /// <summary>
          /// Wrapper for `Graph.add_to_collection()` using the default graph.
          /// </summary>
          /// <typeparam name="T">Type of the value being stored.</typeparam>
          /// <param name="name">The key for the collection.</param>
          /// <param name="value">The value to add to the collection.</param>
          public static void add_to_collection<T>(string name, T value)
          {
              var graph = tf.get_default_graph();
              graph.add_to_collection(name, value);
          }

          /// <summary>
          /// Wrapper for `Graph.add_to_collections()` using the default graph.
          /// </summary>
          /// <typeparam name="T">Type of the value being stored.</typeparam>
          /// <param name="names">The keys of the collections to add <paramref name="value"/> to.</param>
          /// <param name="value">The value to add to the collections.</param>
          public static void add_to_collections<T>(List<string> names, T value)
          {
              var graph = tf.get_default_graph();
              graph.add_to_collections(names, value);
          }

          /// <summary>
          /// Wrapper for `Graph.get_collection()` using the default graph.
          /// </summary>
          /// <param name="key">
          /// The key for the collection. For example, the `GraphKeys` class
          /// contains many standard names for collections.
          /// </param>
          /// <param name="scope">Scope filter, forwarded unchanged to `Graph.get_collection()`.</param>
          /// <returns>
          /// The list of values in the collection with the given `name`, or
          /// an empty list if no value has been added to that collection. The
          /// list contains the values in the order under which they were
          /// collected.
          /// </returns>
          public static object get_collection(string key, string scope = "")
          {
              return get_default_graph().get_collection(key, scope);
          }

          /// <summary>
          /// Returns the default graph, i.e. whatever <c>tf.Graph()</c> returns.
          /// </summary>
          public static Graph get_default_graph()
          {
              return tf.Graph();
          }

          /// <summary>
          /// Returns the graph that an op built from <paramref name="op_input_list"/> should live in.
          /// </summary>
          /// <param name="op_input_list">The op inputs. Not inspected yet (see TODO).</param>
          /// <param name="graph">An explicit graph to use; when null the default graph is returned.</param>
          /// <returns><paramref name="graph"/> if given, otherwise the default graph.</returns>
          public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
          {
              // Mirrors Python's `return graph or get_default_graph()`.
              // TODO: like the Python implementation, derive the graph from the inputs'
              // graph elements and verify that they all belong to the same graph.
              return graph ?? get_default_graph();
          }

          /// <summary>
          /// Converts the given `value` to a `Tensor`.
          /// </summary>
          /// <param name="value">A `Tensor` (returned as-is) or any value convertible to an ndarray.</param>
          /// <param name="dtype">Requested element type. NOTE(review): currently not applied.</param>
          /// <param name="name">Name for the constant created when <paramref name="value"/> is not a `Tensor`.</param>
          /// <returns>The existing tensor, or a new constant tensor holding <paramref name="value"/>.</returns>
          public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
          {
              switch (value)
              {
                  case Tensor val:
                      return val;
                  default:
                      // NOTE(review): `dtype` is not forwarded; the element type is presumably
                      // whatever constant_op.constant infers from the ndarray — confirm.
                      var nd = tensor_util.convert_to_numpy_ndarray(value);
                      return constant_op.constant(nd, name);
              }
          }

          /// <summary>
          /// Converts <paramref name="value"/> to a `Tensor` (composite tensors are not supported yet).
          /// </summary>
          public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
          {
              return internal_convert_to_tensor_or_composite(value: value, dtype: dtype, name: name, as_ref: false);
          }

          /// <summary>
          /// Implementation of <see cref="convert_to_tensor_or_composite"/>; delegates to
          /// <c>internal_convert_to_tensor&lt;Tensor&gt;</c>.
          /// </summary>
          public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
          {
              return internal_convert_to_tensor<Tensor>(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref);
          }

          /// <summary>
          /// Wrapper for `Graph.control_dependencies()` using the default graph.
          /// </summary>
          /// <param name="control_inputs">Operations the created ops must depend on.</param>
          public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
          {
              return get_default_graph().control_dependencies(control_inputs);
          }

          /// <summary>
          /// Creates a TF_Operation.
          /// </summary>
          /// <param name="graph">a `Graph`.</param>
          /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param>
          /// <param name="inputs">
          /// A list of `Tensor`s (corresponding to scalar inputs) and lists of
          /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
          /// "list(int64)"). The length of the list should be equal to the number of
          /// inputs specified by this operation's op def. May be null (no inputs).
          /// </param>
          /// <param name="control_inputs">
          /// A list of `Operation`s to set as control dependencies. May be null (no control inputs).
          /// </param>
          /// <returns>A wrapped TF_Operation*.</returns>
          /// <exception cref="NotImplementedException">An input is neither a `Tensor` nor a `Tensor[]`.</exception>
          public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
          {
              var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

              // Add inputs
              if (inputs != null)
              {
                  foreach (var op_input in inputs)
                  {
                      if (op_input is Tensor[] op_inputs)
                      {
                          c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
                      }
                      else if (op_input is Tensor op_input1)
                      {
                          c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
                      }
                      else
                      {
                          throw new NotImplementedException("_create_c_op");
                      }
                  }
              }

              var status = new Status();

              // Add control inputs
              if (control_inputs != null)
              {
                  foreach (var control_input in control_inputs)
                  {
                      c_api.TF_AddControlInput(op_desc, control_input);
                  }
              }

              // Add attrs: each AttrValue is serialized and handed to the C API as raw bytes.
              foreach (var attr in node_def.Attr)
              {
                  var bytes = attr.Value.ToByteArray();
                  var proto = Marshal.AllocHGlobal(bytes.Length);
                  try
                  {
                      Marshal.Copy(bytes, 0, proto, bytes.Length);
                      c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                  }
                  finally
                  {
                      // TF_SetAttrValueProto parses the buffer during the call, so the
                      // unmanaged copy must be released here or it leaks per attribute.
                      Marshal.FreeHGlobal(proto);
                  }
                  status.Check(true);
              }

              var c_op = c_api.TF_FinishOperation(op_desc, status);
              status.Check(true);
              return c_op;
          }

          /// <summary>
          /// Looks up the `OpDef` registered for op <paramref name="type"/> in <paramref name="graph"/>.
          /// </summary>
          public static OpDef _get_op_def(Graph graph, string type)
          {
              return graph.GetOpDef(type);
          }

          /// <summary>
          /// Builds a `NodeDef` protobuf for an op.
          /// </summary>
          /// <param name="op_type">The op type (e.g. "Add").</param>
          /// <param name="name">The node name.</param>
          /// <param name="device">Device string; left unset when null or empty.</param>
          /// <param name="attrs">Attributes to copy into the node; null means no attributes.</param>
          /// <returns>The populated `NodeDef`.</returns>
          public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
          {
              var node_def = new node_def_pb2.NodeDef();
              node_def.Op = op_type;
              node_def.Name = name;

              // Protobuf string setters reject null, so only assign a real device string.
              if (!string.IsNullOrEmpty(device))
              {
                  node_def.Device = device;
              }

              if (attrs != null)
              {
                  foreach (var attr in attrs)
                  {
                      node_def.Attr.Add(attr.Key, attr.Value);
                  }
              }
              return node_def;
          }

          /// <summary>
          /// Strips a single trailing "/" from a scope name to turn it into an op name.
          /// </summary>
          /// <param name="name">A scope name; null or empty is returned unchanged.</param>
          /// <returns><paramref name="name"/> without its trailing "/", if it had one.</returns>
          public static string _name_from_scope_name(string name)
          {
              // Same as Python's `name[:-1] if (name and name[-1] == "/") else name`;
              // the char comparison is ordinal, independent of the current culture.
              if (!string.IsNullOrEmpty(name) && name[name.Length - 1] == '/')
              {
                  return name.Substring(0, name.Length - 1);
              }
              return name;
          }

          /// <summary>
          /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
          /// </summary>
          /// <remarks>
          /// NOTE(review): this is a partial port. It computes the active name scope and enters
          /// <c>control_dependencies(null)</c> on the default graph via <c>Python.with</c>. The
          /// computed values are not used yet, and no caller code runs inside that context.
          /// </remarks>
          public static void init_scope()
          {
              // Retrieve the active name scope: entering an `init_scope` preserves
              // the name scope of the current context.
              var default_graph = get_default_graph();
              var scope = default_graph.get_name_scope();
              if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/", StringComparison.Ordinal))
              {
                  // Names that end with trailing slashes are treated by `name_scope` as
                  // absolute.
                  scope += "/";
              }
              // inner_device_stack = default_graph._device_function_stack
              // var outer_context = default_graph.as_default;
              Python.with(ops.control_dependencies(null), delegate
              {
                  var outer_graph = get_default_graph();
                  // outer_device_stack = None
              });
          }

          // Holds the next value uid() will return; only mutated through Interlocked.
          private static int uid_number = 0;

          /// <summary>
          /// A unique (within this program execution) integer.
          /// Values start at 0 and increase by one per call; safe to call from multiple threads.
          /// </summary>
          /// <returns>A fresh integer id.</returns>
          public static int uid()
          {
              // Increment returns the new value; subtract 1 to hand out the pre-increment
              // value, exactly like the previous `uid_number++`.
              return Interlocked.Increment(ref uid_number) - 1;
          }

          /// <summary>
          /// Colocates subsequently created ops with <paramref name="op"/> on the default graph.
          /// </summary>
          public static void colocate_with(Operation op, bool ignore_existing = false)
          {
              _colocate_with_for_gradient(op, null, ignore_existing);
          }

          /// <summary>
          /// Colocates subsequently created ops with the op that produces <paramref name="tensor"/>.
          /// </summary>
          public static void colocate_with(Tensor tensor, bool ignore_existing = false)
          {
              _colocate_with_for_gradient(tensor.op, null, ignore_existing);
          }

          /// <summary>
          /// Forwards to `Graph._colocate_with_for_gradient()` on the default graph.
          /// </summary>
          public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false)
          {
              var default_graph = get_default_graph();
              default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
          }

          /// <summary>
          /// Uses the default session to evaluate one or more tensors.
          /// </summary>
          /// <param name="tensor">The Tensor to evaluate.</param>
          /// <param name="feed_dict">
          /// A dictionary that maps Tensor objects (or tensor names) to lists,
          /// numpy ndarrays, TensorProtos, or strings.
          /// </param>
          /// <param name="graph">The graph in which the tensors are defined.</param>
          /// <param name="session">A different session to use to evaluate "tensors".</param>
          /// <returns>
          /// Either a single numpy ndarray if "tensors" is a single tensor; or a list
          /// of numpy ndarrays that each correspond to the respective element in
          /// "tensors".
          /// </returns>
          /// <exception cref="ValueError">
          /// No session is available, or the session's graph differs from <paramref name="graph"/>.
          /// </exception>
          public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null)
          {
              if (session == null)
              {
                  session = get_default_session();
                  if (session == null)
                      throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
                                           "session is registered. Use `with " +
                                           "sess.as_default()` or pass an explicit session to " +
                                           "`eval(session=sess)`");
                  if (session.graph != graph)
                      throw new ValueError("Cannot use the default session to evaluate tensor: " +
                                           "the tensor's graph is different from the session's " +
                                           "graph. Pass an explicit session to " +
                                           "`eval(session=sess)`.");
              }
              else if (session.graph != graph)
              {
                  // An explicit session was passed, so the "default session" message
                  // (and its advice to pass an explicit session) does not apply.
                  throw new ValueError("Cannot use the given session to evaluate tensor: " +
                                       "the tensor's graph is different from the session's " +
                                       "graph.");
              }
              return session.run(tensor, feed_dict);
          }

          /// <summary>
          /// Returns the default session for the current thread.
          /// </summary>
          /// <returns>The default `Session` being used in the current thread.</returns>
          public static Session get_default_session()
          {
              return tf.Session();
          }

          /// <summary>
          /// Returns a gradient function that dispatches on the op type.
          /// </summary>
          /// <param name="op">The op whose gradient is needed; null is returned when it has no inputs.</param>
          /// <returns>
          /// A function mapping (operation, output gradient) to the input gradients, or null.
          /// The function throws <see cref="NotImplementedException"/> for unsupported op types.
          /// </returns>
          public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op)
          {
              if (op.inputs == null) return null;
              return (oper, out_grads) =>
              {
                  // Diagnostic trace only; compiled out of release builds instead of writing to stdout.
                  System.Diagnostics.Debug.WriteLine($"get_gradient_function: {oper.type} '{oper.Name}'");
                  switch (oper.type)
                  {
                      case "Add":
                          var add = math_grad._AddGrad(oper, out_grads);
                          return new Tensor[] { add.Item1, add.Item2 };
                      case "Identity":
                          var id = math_grad._IdGrad(oper, out_grads);
                          return new Tensor[] { id };
                      case "Mul":
                          var mul = math_grad._MulGrad(oper, out_grads);
                          return new Tensor[] { mul.Item1, mul.Item2 };
                      case "Sum":
                          var sum = math_grad._SumGrad(oper, out_grads);
                          return new Tensor[] { sum.Item1, sum.Item2 };
                      case "Sub":
                          var sub = math_grad._SubGrad(oper, out_grads);
                          return new Tensor[] { sub.Item1, sub.Item2 };
                      case "Pow":
                          var pow = math_grad._PowGrad(oper, out_grads);
                          return new Tensor[] { pow.Item1, pow.Item2 };
                      case "RealDiv":
                          var realdiv = math_grad._RealDivGrad(oper, out_grads);
                          return new Tensor[] { realdiv.Item1, realdiv.Item2 };
                      default:
                          throw new NotImplementedException($"get_gradient_function {oper.type}");
                  }
              };
          }

          /// <summary>
          /// Converts each of <paramref name="values"/> to a `Tensor` (indexed slices not supported yet).
          /// </summary>
          public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
          {
              return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
          }

          /// <summary>
          /// Converts <paramref name="value"/> to a `Tensor` (indexed slices not supported yet).
          /// </summary>
          public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
          {
              return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
          }

          /// <summary>
          /// Currently returns <paramref name="value"/> unchanged; the other arguments are ignored.
          /// </summary>
          public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
          {
              return value;
          }

          /// <summary>
          /// Converts each element of <paramref name="values"/>, keeping null entries as null.
          /// When <paramref name="name"/> is given, element i is named "{name}_{i}".
          /// </summary>
          public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
          {
              var ret = new List<Tensor>();
              foreach (var (i, value) in Python.enumerate(values))
              {
                  if (value == null)
                  {
                      ret.Add(value);
                  }
                  else
                  {
                      var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
                      ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref));
                  }
              }
              return ret.ToArray();
          }

          /// <summary>
          /// Converts each element of <paramref name="values"/> to a `Tensor`.
          /// When <paramref name="name"/> is given, element i is named "{name}_{i}".
          /// </summary>
          public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
              string name = "", DataType preferred_dtype = DataType.DtInvalid,
              bool as_ref = false)
          {
              var ret = new List<Tensor>();
              foreach ((int i, T value) in Python.enumerate(values))
              {
                  string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
                  ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
              }
              return ret.ToArray();
          }

          /// <summary>
          /// Converts <paramref name="value"/> to a `Tensor`, dispatching on the static type <typeparamref name="T"/>.
          /// </summary>
          /// <remarks>
          /// NOTE(review): <paramref name="dtype"/> and <paramref name="preferred_dtype"/> are currently ignored.
          /// Dispatch uses <c>typeof(T).Name</c>, so a value passed as a base or interface type
          /// (e.g. <c>object</c>) is not converted even if its runtime type is supported.
          /// </remarks>
          /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
          public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
              string name = "", DataType preferred_dtype = DataType.DtInvalid,
              bool as_ref = false)
          {
              switch (typeof(T).Name)
              {
                  case "Tensor":
                      return value as Tensor;
                  case "String":
                      return constant_op.constant(Convert.ToString(value), name);
                  case "String[]":
                      return constant_op.constant(value as string[], name);
                  case "Int32":
                      return constant_op.constant(Convert.ToInt32(value), name);
                  case "Double":
                      return constant_op.constant(Convert.ToDouble(value), name);
                  case "RefVariable":
                      return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref);
                  default:
                      throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor");
              }
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。