You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.py.cs 4.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. using System.Threading;
  6. using Tensorflow;
  7. using node_def_pb2 = Tensorflow;
  8. using Google.Protobuf;
  9. using System.Linq;
  namespace Tensorflow
  {
      public partial class ops
      {
          // Backing counter for uid(); incremented atomically so concurrent callers never share an id.
          private static int _uid_counter;

          /// <summary>
          /// Adds <paramref name="value"/> to the collection named <paramref name="name"/>
          /// on the graph returned by <c>tf.get_default_graph()</c>.
          /// </summary>
          /// <param name="name">Collection key.</param>
          /// <param name="value">Value to append to the collection.</param>
          public static void add_to_collection<T>(string name, T value)
          {
              var graph = tf.get_default_graph();
              graph.add_to_collection(name, value);
          }

          /// <summary>
          /// Adds <paramref name="value"/> to every collection listed in <paramref name="names"/>
          /// on the graph returned by <c>tf.get_default_graph()</c>.
          /// </summary>
          /// <param name="names">Collection keys.</param>
          /// <param name="value">Value to append to each collection.</param>
          public static void add_to_collections<T>(List<string> names, T value)
          {
              var graph = tf.get_default_graph();
              graph.add_to_collections(names, value);
          }

          /// <summary>
          /// Returns the contents of the collection named <paramref name="key"/> from
          /// <see cref="get_default_graph"/>.
          /// </summary>
          /// <param name="key">Collection key.</param>
          public static object get_collection(string key)
          {
              return get_default_graph().get_collection(key);
          }

          /// <summary>
          /// Returns the graph produced by <c>tf.Graph()</c>.
          /// </summary>
          /// <remarks>
          /// NOTE(review): whether <c>tf.Graph()</c> hands back a shared default graph or a fresh
          /// instance is defined elsewhere; confirm before relying on identity.
          /// </remarks>
          public static Graph get_default_graph()
          {
              return tf.Graph();
          }

          /// <summary>
          /// Picks the graph that an op built from <paramref name="op_input_list"/> should belong to.
          /// </summary>
          /// <param name="op_input_list">Input tensors of the op being created (currently not inspected).</param>
          /// <param name="graph">Explicitly requested graph; returned as-is when non-null.</param>
          /// <returns><paramref name="graph"/> when supplied, otherwise <see cref="get_default_graph"/>.</returns>
          public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
          {
              // TODO: validate that every input tensor lives in the returned graph.
              // Previously the explicit graph argument was ignored and the loop over
              // op_input_list did nothing.
              return graph ?? get_default_graph();
          }

          /// <summary>
          /// Converts <paramref name="value"/> to a <see cref="Tensor"/>.
          /// </summary>
          /// <param name="value">An existing Tensor (returned unchanged) or a value accepted by
          /// <c>tensor_util.convert_to_numpy_ndarray</c>.</param>
          /// <param name="name">Name passed to <c>tf.constant</c> when a new constant is created.</param>
          /// <returns>The input tensor, or a new constant built from the converted array.</returns>
          public static Tensor convert_to_tensor(object value, string name = "")
          {
              switch (value)
              {
                  case Tensor val:
                      return val;
                  default:
                      var nd = tensor_util.convert_to_numpy_ndarray(value);
                      return tf.constant(nd, name);
              }
          }

          /// <summary>
          /// Creates a native operation in <paramref name="graph"/> from <paramref name="node_def"/>
          /// and wires <paramref name="inputs"/> into it.
          /// </summary>
          /// <param name="graph">Graph that receives the new operation.</param>
          /// <param name="node_def">Supplies the op type, node name and attributes.</param>
          /// <param name="inputs">Input tensors, each added as a single input; may be null.</param>
          /// <returns>Handle returned by <c>TF_FinishOperation</c>.</returns>
          /// <remarks>Errors reported through the C API status are surfaced by <c>status.Check(true)</c>.</remarks>
          public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
          {
              var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

              // Add inputs. List-typed inputs (TF_AddInputList) are not supported yet; the former
              // branch for them was unreachable and would have re-added the whole list per element.
              if (inputs != null)
              {
                  foreach (var op_input in inputs)
                  {
                      c_api.TF_AddInput(op_desc, op_input._as_tf_output());
                  }
              }

              var status = new Status();

              // TODO: add control inputs.

              // Add attrs: each AttrValue is serialized and passed to the C API as an unmanaged buffer.
              foreach (var attr in node_def.Attr)
              {
                  var bytes = attr.Value.ToByteArray();
                  var proto = Marshal.AllocHGlobal(bytes.Length);
                  try
                  {
                      Marshal.Copy(bytes, 0, proto, bytes.Length);
                      c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                  }
                  finally
                  {
                      // The buffer was never released before, leaking one allocation per attribute.
                      Marshal.FreeHGlobal(proto);
                  }
                  status.Check(true);
              }

              var c_op = c_api.TF_FinishOperation(op_desc, status);
              status.Check(true);
              return c_op;
          }

          /// <summary>
          /// Looks up the <see cref="OpDef"/> registered for <paramref name="type"/> via <paramref name="graph"/>.
          /// </summary>
          public static OpDef _get_op_def(Graph graph, string type)
          {
              return graph.GetOpDef(type);
          }

          /// <summary>
          /// Builds a <see cref="NodeDef"/> with the given op type, name and attributes.
          /// </summary>
          /// <param name="op_type">Value for <c>NodeDef.Op</c>.</param>
          /// <param name="name">Value for <c>NodeDef.Name</c>.</param>
          /// <param name="device">NOTE(review): currently not applied to the NodeDef.</param>
          /// <param name="attrs">Attributes to copy into <c>NodeDef.Attr</c>; null means none.</param>
          public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
          {
              var node_def = new node_def_pb2.NodeDef();
              node_def.Op = op_type;
              node_def.Name = name;

              // attrs defaults to null; enumerating it unconditionally used to throw NullReferenceException.
              if (attrs != null)
              {
                  foreach (var attr in attrs)
                  {
                      node_def.Attr.Add(attr.Key, attr.Value);
                  }
              }

              return node_def;
          }

          /// <summary>
          /// Strips a single trailing '/' from a scope name to obtain the plain name.
          /// </summary>
          /// <param name="name">Scope name; null or empty values are returned unchanged.</param>
          public static string _name_from_scope_name(string name)
          {
              if (!string.IsNullOrEmpty(name) && name.EndsWith("/", StringComparison.Ordinal))
              {
                  return name.Substring(0, name.Length - 1);
              }

              return name;
          }

          /// <summary>
          /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
          /// </summary>
          /// <remarks>
          /// TODO: stub. It only reads the current name scope; no context is entered or exited and the
          /// computed scope is not used yet.
          /// </remarks>
          public static void init_scope()
          {
              // Retrieve the active name scope: entering an init_scope preserves
              // the name scope of the current context.
              var default_graph = get_default_graph();
              var scope = default_graph.get_name_scope();
              if (!string.IsNullOrEmpty(scope) && !scope.EndsWith("/", StringComparison.Ordinal))
              {
                  // Names that end with trailing slashes are treated by name_scope as absolute.
                  scope += "/";
              }

              // TODO: capture the outer graph / device function stack and enter the outer context.
          }

          /// <summary>
          /// Returns a process-wide unique integer. The first call returns 1, and each later call
          /// returns the next value.
          /// </summary>
          public static int uid()
          {
              return Interlocked.Increment(ref _uid_counter);
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。