You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

OpDefLibrary.cs 8.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. using System;
  2. using System.Collections.Generic;
  3. using System.ComponentModel;
  4. using System.Dynamic;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Runtime.InteropServices;
  8. using System.Text;
  9. using static Tensorflow.OpDef.Types;
  10. namespace Tensorflow
  11. {
  12. public class OpDefLibrary
  13. {
  /// <summary>
  /// Creates an operation of type <paramref name="op_type_name"/> in the default graph,
  /// converting the supplied arguments into input tensors and attribute protos.
  /// </summary>
  /// <param name="op_type_name">Name of the op whose definition is looked up in the default graph.</param>
  /// <param name="name">Name used for the name scope; falls back to <paramref name="op_type_name"/> when null or empty.</param>
  /// <param name="args">Object carrying the op's inputs and attrs; it is turned into a keyword dictionary by <c>ConvertToDict</c>.</param>
  /// <returns>The operation returned by the graph's <c>create_op</c>.</returns>
  /// <exception cref="TypeError">A list-typed input argument was not given a list value.</exception>
  /// <exception cref="KeyNotFoundException">An input argument is missing, or an attr is missing and declares no default value.</exception>
  /// <exception cref="InvalidDataException">An attr has a type this method cannot convert.</exception>
  public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null)
  {
      var keywords = ConvertToDict(args);
      var g = ops.get_default_graph();
      var op_def = g.GetOpDef(op_type_name);

      // Default name if not specified.
      if (String.IsNullOrEmpty(name))
          name = op_type_name;

      // TODO: op_def.Deprecation is not checked yet.

      var attrs = new Dictionary<string, object>();
      var inputs = new List<Tensor>();
      var input_types = new List<TF_DataType>();

      Operation op = null;
      Python.with<ops.name_scope>(new ops.name_scope(name), scope =>
      {
          // Perform input type inference
          foreach (var input_arg in op_def.InputArg)
          {
              var input_name = input_arg.Name;

              // Fail with a message that names the missing input instead of the
              // generic dictionary-indexer error.
              if (!keywords.ContainsKey(input_name))
                  throw new KeyNotFoundException($"No argument for input '{input_name}' of '{op_type_name}' Op.");

              var values = keywords[input_name];

              // Goals:
              // * Convert values to Tensors if it contains constants.
              // * Verify that values is a list if that matches the input_arg's
              //   type.
              // * If the input_arg's type is determined by attrs, either set
              //   those attrs and validate those attr values are legal (if
              //   they have not yet been set) or validate the input matches
              //   the type indicated by the attrs (if they have already been
              //   inferred via an earlier input).
              // * If the input_arg has an explicit type, make sure the input
              //   conforms.
              if (_IsListParameter(input_arg))
              {
                  DataType dtype = DataType.DtInvalid;
                  DataType default_dtype = DataType.DtInvalid;

                  if (!_IsListValue(values))
                      throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");

                  if (input_arg.Type != DataType.DtInvalid)
                      dtype = input_arg.Type;
                  // TODO: inputs whose type is driven by input_arg.NumberAttr are not handled yet.

                  if (input_arg.IsRef && dtype != DataType.DtInvalid)
                      dtype = dtype.as_base_dtype();

                  values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef);
                  inputs.AddRange(values as Tensor[]);
              }
              else
              {
                  // Wrap non-tensor values in a constant and store it back so the
                  // lookup below sees the converted value.
                  if (!(values is Tensor))
                      keywords[input_name] = constant_op.constant(values, input_name);

                  if (keywords[input_name] is Tensor value)
                  {
                      inputs.Add(value);

                      // The input's dtype determines the attr named by TypeAttr.
                      if (!String.IsNullOrEmpty(input_arg.TypeAttr))
                          attrs[input_arg.TypeAttr] = value.dtype;

                      values = new Tensor[] { value };
                  }
              }

              // Append only THIS input's base dtypes. Accumulating them in a list
              // shared across iterations re-added earlier inputs' types each time.
              input_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
          }

          // Process remaining attrs
          foreach (var attr in op_def.Attr)
          {
              if (keywords.ContainsKey(attr.Name))
                  attrs[attr.Name] = keywords[attr.Name];
          }

          // Convert attr values to AttrValue protos.
          var attr_protos = new Dictionary<string, AttrValue>();
          foreach (var attr_def in op_def.Attr)
          {
              var key = attr_def.Name;

              if (!attrs.TryGetValue(key, out var value))
              {
                  // Attr neither inferred from an input nor supplied by the caller:
                  // use the default declared in the op definition, if there is one.
                  if (attr_def.DefaultValue == null)
                      throw new KeyNotFoundException($"No argument for attr '{key}' of '{op_type_name}' Op.");

                  // Clone so the op definition's own message is not shared with the new op.
                  attr_protos[key] = attr_def.DefaultValue.Clone();
                  continue;
              }

              var attr_value = new AttrValue();
              switch (attr_def.Type)
              {
                  case "string":
                      attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
                      break;
                  case "type":
                      attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                      break;
                  case "bool":
                      attr_value.B = (bool)value;
                      break;
                  case "shape":
                      // An explicit null shape means "use the declared default".
                      attr_value.Shape = value == null ?
                          attr_def.DefaultValue.Shape :
                          tensor_util.as_shape((long[])value);
                      break;
                  default:
                      throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
              }

              attr_protos[key] = attr_value;
          }

          // Determine output types (possibly using attrs)
          var output_types = new List<TF_DataType>();
          foreach (var arg in op_def.OutputArg)
          {
              if (!String.IsNullOrEmpty(arg.NumberAttr))
              {
                  // TODO: outputs sized by a NumberAttr are not handled yet.
              }
              else if (!String.IsNullOrEmpty(arg.TypeAttr))
              {
                  output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
              }
          }

          // Add Op to graph
          op = g.create_op(op_type_name, inputs, output_types.ToArray(),
              name: scope,
              input_types: input_types.ToArray(),
              attrs: attr_protos,
              op_def: op_def);
      });

      return op;
  }
  /// <summary>
  /// Converts a <c>TF_DataType</c> to the <c>DataType</c> enum value of its base dtype.
  /// </summary>
  /// <param name="v">The data type to convert.</param>
  /// <param name="attr_def">The attr definition the value belongs to; currently unused.</param>
  /// <returns>The result of <c>as_datatype_enum()</c> applied to the base dtype of <paramref name="v"/>.</returns>
  public DataType _MakeType(TF_DataType v, AttrDef attr_def)
  {
      var base_dtype = v.as_base_dtype();
      return base_dtype.as_datatype_enum();
  }
  /// <summary>
  /// Tells whether an argument definition describes a list of tensors,
  /// i.e. it names either a <c>NumberAttr</c> or a <c>TypeListAttr</c>.
  /// </summary>
  /// <param name="arg">The argument definition to inspect.</param>
  /// <returns><c>true</c> when either attribute name is non-empty; otherwise <c>false</c>.</returns>
  private bool _IsListParameter(ArgDef arg)
  {
      bool has_number_attr = !String.IsNullOrEmpty(arg.NumberAttr);
      return has_number_attr || !String.IsNullOrEmpty(arg.TypeListAttr);
  }
  /// <summary>
  /// Tells whether a value counts as a list input. Only a non-null
  /// <c>Tensor[]</c> is accepted.
  /// </summary>
  /// <param name="v">The value to test.</param>
  /// <returns><c>true</c> for a <c>Tensor[]</c> instance; <c>false</c> for anything else, including null.</returns>
  private bool _IsListValue(object v)
  {
      return v is Tensor[];
  }
  /// <summary>
  /// Flattens an argument object into a name-to-value dictionary.
  /// </summary>
  /// <remarks>
  /// An <see cref="IDictionary{TKey, TValue}"/> with string keys (which includes
  /// <c>ExpandoObject</c>) has its entries copied directly; reflecting over such
  /// an object would expose the container's own properties instead of its
  /// entries. Any other object contributes one entry per public property.
  /// The name <c>_ref_</c> is stored as <c>ref</c>, since <c>ref</c> is a C#
  /// keyword and cannot be used as a property name directly.
  /// </remarks>
  /// <param name="dyn">The argument object; null yields an empty dictionary.</param>
  /// <returns>A new dictionary holding the collected names and values.</returns>
  /// <exception cref="ArgumentException">Two names collapse to the same key.</exception>
  private Dictionary<string, object> ConvertToDict(dynamic dyn)
  {
      var dictionary = new Dictionary<string, object>();

      // Bind statically from here on; nothing below needs dynamic dispatch.
      object source = dyn;
      if (source == null)
          return dictionary;

      if (source is IDictionary<string, object> entries)
      {
          foreach (var entry in entries)
          {
              // avoid .net keyword
              string key = entry.Key == "_ref_" ? "ref" : entry.Key;
              dictionary.Add(key, entry.Value);
          }
          return dictionary;
      }

      foreach (PropertyDescriptor propertyDescriptor in TypeDescriptor.GetProperties(source))
      {
          object obj = propertyDescriptor.GetValue(source);
          string name = propertyDescriptor.Name;
          // avoid .net keyword
          if (name == "_ref_")
              name = "ref";
          dictionary.Add(name, obj);
      }

      return dictionary;
  }
  195. }
  196. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。