You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.cs 3.4 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. using System.Threading;
  6. using Tensorflow;
  7. using node_def_pb2 = Tensorflow;
  8. using Google.Protobuf;
  9. using System.Linq;
  namespace Tensorflow
  {
      /// <summary>
      /// Static helpers mirroring TensorFlow's Python <c>ops</c> module: graph access,
      /// tensor conversion, C-API operation construction and name-scope utilities.
      /// </summary>
      public static class ops
      {
          // Backing counter for uid(); only ever read/written through Interlocked.
          private static int _uid_counter;

          /// <summary>
          /// Returns the graph obtained from <c>tf.Graph()</c>.
          /// </summary>
          public static Graph get_default_graph()
          {
              return tf.Graph();
          }

          /// <summary>
          /// Converts an arbitrary value into a constant tensor. The value is first passed
          /// through <c>tensor_util.convert_to_numpy_ndarray</c> and the result handed to
          /// <c>tf.constant</c>.
          /// </summary>
          /// <param name="value">The value to convert.</param>
          /// <param name="name">Optional name forwarded to <c>tf.constant</c>.</param>
          /// <returns>The tensor produced by <c>tf.constant</c>.</returns>
          public static Tensor convert_to_tensor(object value, string name = "")
          {
              var nd = tensor_util.convert_to_numpy_ndarray(value);
              return tf.constant(nd, name);
          }

          /// <summary>
          /// Builds a native operation in <paramref name="graph"/> from a NodeDef through the
          /// TensorFlow C API: creates the op description, adds inputs and attributes, then
          /// finishes the operation.
          /// </summary>
          /// <param name="graph">Graph in which the operation is created.</param>
          /// <param name="node_def">Supplies the op type, op name and attributes.</param>
          /// <param name="inputs">Tensors added, in order, as single inputs; may be null.</param>
          /// <returns>The handle returned by <c>TF_FinishOperation</c>.</returns>
          /// <exception cref="Exception">Thrown when finishing the operation reports a non-OK status.</exception>
          public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
          {
              var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

              // Add inputs. Every tensor is added as an individual input; list-valued inputs
              // (TF_AddInputList) are not supported yet.
              if (inputs != null)
              {
                  foreach (var op_input in inputs)
                  {
                      c_api.TF_AddInput(op_desc, op_input._as_tf_output());
                  }
              }

              var status = new Status();

              // TODO: control inputs are not added yet.

              // Add attrs. Each AttrValue is serialized and passed to the C API through an
              // unmanaged buffer. The buffer is released right after the call so it does
              // not leak once per attribute.
              foreach (var attr in node_def.Attr)
              {
                  var bytes = attr.Value.ToByteArray();
                  var proto = Marshal.AllocHGlobal(bytes.Length);
                  try
                  {
                      Marshal.Copy(bytes, 0, proto, bytes.Length);
                      c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                  }
                  finally
                  {
                      Marshal.FreeHGlobal(proto);
                  }
                  status.Check(true);
              }

              var c_op = c_api.TF_FinishOperation(op_desc, status);
              if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
              return c_op;
          }

          /// <summary>
          /// Looks up the OpDef registered for <paramref name="type"/> via <c>graph.GetOpDef</c>.
          /// </summary>
          /// <param name="graph">Graph used for the lookup.</param>
          /// <param name="type">The op type name.</param>
          public static OpDef _get_op_def(Graph graph, string type)
          {
              return graph.GetOpDef(type);
          }

          /// <summary>
          /// Creates a NodeDef with the given op type, node name, optional device and optional
          /// attributes.
          /// </summary>
          /// <param name="op_type">Value stored in <c>NodeDef.Op</c>.</param>
          /// <param name="name">Value stored in <c>NodeDef.Name</c>.</param>
          /// <param name="device">Device string stored in <c>NodeDef.Device</c> when non-empty.</param>
          /// <param name="attrs">Attributes copied into <c>NodeDef.Attr</c>; null means none.</param>
          /// <returns>The populated NodeDef.</returns>
          public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
          {
              var node_def = new node_def_pb2.NodeDef();
              node_def.Op = op_type;
              node_def.Name = name;

              // attrs defaults to null, so guard before enumerating it.
              if (attrs != null)
              {
                  foreach (var attr in attrs)
                  {
                      node_def.Attr.Add(attr.Key, attr.Value);
                  }
              }

              if (!string.IsNullOrEmpty(device))
              {
                  node_def.Device = device;
              }

              return node_def;
          }

          /// <summary>
          /// Resolves a name scope on the graph from <c>get_default_graph()</c>. Uses
          /// <paramref name="name"/> when it is non-empty, otherwise <paramref name="default_name"/>.
          /// </summary>
          /// <param name="name">Explicit scope name; takes precedence when non-empty.</param>
          /// <param name="default_name">Fallback scope name used when <paramref name="name"/> is null or empty.</param>
          /// <param name="values">Currently unused.</param>
          /// <returns>The string returned by the graph's <c>name_scope</c>.</returns>
          public static string name_scope(string name, string default_name = "", object values = null)
          {
              // Fall back to default_name only when no explicit name was supplied.
              string _name = String.IsNullOrEmpty(name) ? default_name : name;
              var g = get_default_graph();
              return g.name_scope(_name);
          }

          /// <summary>
          /// Strips a single trailing '/' from a scope name; any other input (including
          /// null or empty) is returned unchanged.
          /// </summary>
          /// <param name="name">A scope name, possibly ending in '/'.</param>
          public static string _name_from_scope_name(string name)
          {
              // Ordinal comparison: a culture-sensitive EndsWith can match when the string
              // ends in an ignorable character, and Substring would then drop the wrong char.
              if (!string.IsNullOrEmpty(name) && name.EndsWith("/", StringComparison.Ordinal))
              {
                  return name.Substring(0, name.Length - 1);
              }

              return name;
          }

          /// <summary>
          /// Returns a process-wide increasing integer id, starting at 1. Safe to call
          /// from multiple threads.
          /// </summary>
          public static int uid()
          {
              return Interlocked.Increment(ref _uid_counter);
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。