You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

OpDefLibrary.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Runtime.InteropServices;
  5. using System.Text;
  6. using static Tensorflow.OpDef.Types;
  7. namespace Tensorflow
  8. {
  /// <summary>
  /// Builds graph operations from their registered OpDef schema: resolves inputs and
  /// attributes from caller-supplied keyword values, converts attributes to
  /// <see cref="AttrValue"/> protos, infers output dtypes, and hands everything to the
  /// default graph's <c>create_op</c>.
  /// </summary>
  public class OpDefLibrary
  {
      /// <summary>
      /// Creates an operation of type <paramref name="op_type_name"/> in the default graph.
      /// </summary>
      /// <param name="op_type_name">Registered op type; used to look up the OpDef via <c>g.GetOpDef</c>.</param>
      /// <param name="name">Requested op name. Falls back to <paramref name="op_type_name"/> when null or empty,
      /// then is passed through <c>g.unique_name</c>.</param>
      /// <param name="keywords">Values keyed by OpDef input-arg name or attr name. Input values must be
      /// <see cref="Tensor"/> instances. May be null, which is treated as "no values supplied".</param>
      /// <returns>The operation returned by the graph's <c>create_op</c>.</returns>
      /// <exception cref="KeyNotFoundException">An OpDef input is missing from <paramref name="keywords"/>, or an
      /// attr has no supplied/inferred value and the OpDef declares no default for it.</exception>
      /// <exception cref="ArgumentException">A value supplied for an input is not a <see cref="Tensor"/>.</exception>
      public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null)
      {
          keywords = keywords ?? new Dictionary<string, object>();

          var g = ops.get_default_graph();
          var op_def = g.GetOpDef(op_type_name);
          if (String.IsNullOrEmpty(name))
          {
              name = op_type_name;
          }
          // Scope-style name with a trailing "/" (mirrors TF's Python name_scope usage);
          // passed to create_op below as the op name.
          string scope = g.unique_name(name) + "/";

          var attrs = new Dictionary<string, object>();

          // Collect input tensors; each input's dtype also fills its type attr, if any.
          var inputs = new List<Tensor>();
          var input_types = new List<TF_DataType>();
          foreach (var input_arg in op_def.InputArg)
          {
              var input_name = input_arg.Name;
              object input_value;
              if (!keywords.TryGetValue(input_name, out input_value))
              {
                  throw new KeyNotFoundException($"No argument for input '{input_name}' of op '{op_type_name}'.");
              }
              if (!(input_value is Tensor))
              {
                  var actual = input_value == null ? "null" : input_value.GetType().Name;
                  throw new ArgumentException(
                      $"Input '{input_name}' of op '{op_type_name}' must be a Tensor, got {actual}.",
                      nameof(keywords));
              }
              var tensor = (Tensor)input_value;
              inputs.Add(tensor);

              if (!String.IsNullOrEmpty(input_arg.TypeAttr))
              {
                  attrs[input_arg.TypeAttr] = tensor.dtype;
              }

              // Every input contributes exactly one entry so input_types stays index-aligned
              // with inputs. NOTE(review): ref inputs are not validated to carry ref dtypes.
              input_types.Add(tensor.dtype);
          }

          // Explicitly supplied attr values take precedence over inferred ones.
          foreach (var attr_def in op_def.Attr)
          {
              object supplied;
              if (keywords.TryGetValue(attr_def.Name, out supplied))
              {
                  attrs[attr_def.Name] = supplied;
              }
          }

          // Convert attr values to AttrValue protos, falling back to the OpDef's defaults.
          var attr_protos = new Dictionary<string, AttrValue>();
          foreach (var attr_def in op_def.Attr)
          {
              var key = attr_def.Name;
              object value;
              if (attrs.TryGetValue(key, out value))
              {
                  attr_protos[key] = _MakeAttrValue(value, attr_def);
              }
              else if (attr_def.DefaultValue != null)
              {
                  attr_protos[key] = attr_def.DefaultValue;
              }
              else
              {
                  throw new KeyNotFoundException($"No argument for attr '{key}' of op '{op_type_name}'.");
              }
          }

          // Determine output types (possibly using attrs), following TF's OpDef rules.
          var output_types = new List<TF_DataType>();
          foreach (var arg in op_def.OutputArg)
          {
              if (!String.IsNullOrEmpty(arg.NumberAttr))
              {
                  // N outputs of a single dtype; N comes from the int attr named by NumberAttr.
                  var count = attr_protos[arg.NumberAttr].I;
                  var dtype = String.IsNullOrEmpty(arg.TypeAttr)
                      ? (TF_DataType)arg.Type
                      : (TF_DataType)attr_protos[arg.TypeAttr].Type;
                  for (long i = 0; i < count; i++)
                  {
                      output_types.Add(dtype);
                  }
              }
              else if (!String.IsNullOrEmpty(arg.TypeAttr))
              {
                  output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
              }
              else if (!String.IsNullOrEmpty(arg.TypeListAttr))
              {
                  // List attrs are not converted by _MakeAttrValue yet, so List may be unset.
                  var type_list = attr_protos[arg.TypeListAttr].List;
                  if (type_list != null)
                  {
                      foreach (var t in type_list.Type)
                      {
                          output_types.Add((TF_DataType)t);
                      }
                  }
              }
              else
              {
                  // The output dtype is fixed by the OpDef itself.
                  output_types.Add((TF_DataType)arg.Type);
              }
          }

          // Add Op to graph
          var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
              name: scope,
              input_types: input_types.ToArray(),
              attrs: attr_protos,
              op_def: op_def);
          return op;
      }

      /// <summary>
      /// Converts one attr value to an <see cref="AttrValue"/> proto based on <c>attr_def.Type</c>.
      /// </summary>
      /// <param name="value">Supplied or inferred attr value.</param>
      /// <param name="attr_def">The OpDef's declaration of this attr.</param>
      /// <returns>The populated proto. Attr kinds not handled below yield an empty AttrValue.</returns>
      private AttrValue _MakeAttrValue(object value, AttrDef attr_def)
      {
          var attr_value = new AttrValue();
          switch (attr_def.Type)
          {
              case "type":
                  attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                  break;
              case "bool":
                  attr_value.B = (bool)value;
                  break;
              case "int":
                  attr_value.I = Convert.ToInt64(value);
                  break;
              case "float":
                  attr_value.F = Convert.ToSingle(value);
                  break;
              case "shape":
                  // TODO(review): the supplied shape value is ignored; an empty shape proto is emitted.
                  attr_value.Shape = new TensorShapeProto();
                  break;
              // NOTE(review): string, list(...), tensor and func attrs are not converted yet.
          }
          return attr_value;
      }

      /// <summary>
      /// Converts a <see cref="TF_DataType"/> to the proto <see cref="DataType"/> enum.
      /// </summary>
      /// <param name="v">The dtype to convert.</param>
      /// <param name="attr_def">The attr being set. Currently unused; TF's reference implementation
      /// checks the dtype against <c>attr_def.allowed_values</c> here.</param>
      /// <returns>The result of <c>v.as_datatype_enum()</c>.</returns>
      public DataType _MakeType(TF_DataType v, AttrDef attr_def)
      {
          return v.as_datatype_enum();
      }
  }
  102. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。