You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

OpDefLibrary.cs 14 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. using System;
  2. using System.Collections.Generic;
  3. using System.ComponentModel;
  4. using System.Dynamic;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Runtime.InteropServices;
  8. using System.Text;
  9. using static Tensorflow.OpDef.Types;
  namespace Tensorflow
  {
      /// <summary>
      /// Creates graph operations from their registered <see cref="OpDef"/>: resolves each input
      /// argument to tensors, infers the length/type attrs those inputs imply, converts all attr
      /// values to <see cref="AttrValue"/> protos and adds the op to the default graph.
      /// Port of TensorFlow Python's <c>op_def_library._apply_op_helper</c>.
      /// </summary>
      public class OpDefLibrary : Python
      {
          /// <summary>
          /// Adds an op of type <paramref name="op_type_name"/> to the default graph.
          /// </summary>
          /// <param name="op_type_name">Op type; its <see cref="OpDef"/> is fetched with <c>GetOpDef</c> on the default graph.</param>
          /// <param name="name">Name scope for the op; <paramref name="op_type_name"/> is used when null or empty.</param>
          /// <param name="args">Input and attr arguments, turned into a name → value dictionary by <c>ConvertToDict</c>.
          /// An input may also be supplied under its name with a trailing '_'.</param>
          /// <returns>The <see cref="Operation"/> returned by <c>create_op</c>.</returns>
          /// <exception cref="TypeError">An input argument is missing, a list input is not a <see cref="Tensor"/>[],
          /// or an attr has neither a supplied/inferred value nor an OpDef default.</exception>
          /// <exception cref="ValueError">A list input or an int attr is below the minimum declared by the OpDef.</exception>
          public Operation _apply_op_helper(string op_type_name, string name = null, object args = null)
          {
              Dictionary<string, object> keywords = ConvertToDict(args);
              var g = ops.get_default_graph();
              var op_def = g.GetOpDef(op_type_name);

              // Default name if not specified.
              if (String.IsNullOrEmpty(name))
                  name = op_type_name;

              // TODO: op_def.Deprecation is not acted on yet; deprecated ops are neither rejected nor warned about.

              var default_type_attr_map = _BuildDefaultTypeAttrMap(op_def);
              var attrs = new Dictionary<string, object>();
              var inputs = new List<Tensor>();
              var input_types = new List<TF_DataType>();

              return with(ops.name_scope(name), scope =>
              {
                  var inferred_from = new Dictionary<string, object>();

                  // Perform input type inference, one input argument at a time in OpDef order.
                  foreach (var input_arg in op_def.InputArg)
                  {
                      var input_name = input_arg.Name;
                      if (!keywords.TryGetValue(input_name, out object values))
                      {
                          // Fallback spelling with a trailing '_' (presumably for names that clash with keywords).
                          if (!keywords.TryGetValue(input_name + "_", out values))
                              throw new TypeError("No argument for input " + input_name);
                          input_name += "_";
                      }

                      // Goals (carried over from TensorFlow Python):
                      // * Convert values to Tensors if it contains constants.
                      // * Verify that values is a list if that matches the input_arg's type.
                      // * If the input_arg's type is determined by attrs, either set those attrs
                      //   (if not yet set) or use the type already inferred via an earlier input.
                      // * If the input_arg has an explicit type, make sure the input conforms.
                      DataType dtype = DataType.DtInvalid;
                      DataType default_dtype = DataType.DtInvalid;
                      if (_IsListParameter(input_arg))
                      {
                          if (!_IsListValue(values))
                              throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");

                          if (input_arg.Type != DataType.DtInvalid)
                              dtype = input_arg.Type;
                          else if (!String.IsNullOrEmpty(input_arg.NumberAttr))
                          {
                              if (attrs.TryGetValue(input_arg.TypeAttr, out var inferred_list_type))
                                  dtype = (DataType)inferred_list_type;
                              else if (values is Tensor[] list_values && list_values.Length > 0)
                                  // Use the first tensor's dtype; the length guard avoids indexing an empty list.
                                  dtype = list_values[0].dtype.as_datatype_enum();

                              if (dtype == DataType.DtInvalid && default_type_attr_map.TryGetValue(input_arg.TypeAttr, out var list_default))
                                  default_dtype = list_default;
                          }

                          if (input_arg.IsRef && dtype != DataType.DtInvalid)
                              dtype = dtype.as_base_dtype();

                          values = ops.internal_convert_n_to_tensor(values,
                              name: input_arg.Name,
                              dtype: dtype.as_tf_dtype(),
                              preferred_dtype: default_dtype.as_tf_dtype(),
                              as_ref: input_arg.IsRef);
                      }
                      else
                      {
                          if (input_arg.Type != DataType.DtInvalid)
                              dtype = input_arg.Type;
                          else if (attrs.TryGetValue(input_arg.TypeAttr, out var inferred_type))
                              dtype = (DataType)inferred_type;
                          else if (default_type_attr_map.TryGetValue(input_arg.TypeAttr, out var default_type))
                              default_dtype = default_type;

                          var value = ops.internal_convert_to_tensor(values,
                              name: input_name,
                              dtype: dtype.as_tf_dtype(),
                              as_ref: input_arg.IsRef,
                              preferred_dtype: default_dtype.as_tf_dtype());
                          // A single input is handled as a one-element list from here on.
                          values = new Tensor[] { value };
                      }

                      if (!(values is Tensor[] converted))
                          throw new NotImplementedException("_IsListParameter");

                      var types = converted.Select(x => x.dtype).ToList();
                      var base_types = converted.Select(x => x.dtype.as_base_dtype()).ToList();
                      inputs.AddRange(converted);

                      SetAttrs(op_type_name,
                          input_arg,
                          op_def,
                          attrs,
                          inferred_from,
                          types,
                          base_types,
                          input_types,
                          values);
                  }

                  // Process remaining attrs: explicitly supplied attr values overwrite inferred ones.
                  foreach (var attr in op_def.Attr)
                  {
                      if (keywords.TryGetValue(attr.Name, out var attr_arg))
                          attrs[attr.Name] = attr_arg;
                  }

                  // Convert attr values to AttrValue protos.
                  var attr_protos = _ConvertAttrsToProtos(op_def, attrs);
                  attrs.Clear();

                  // Determine output types (possibly using attrs)
                  var output_types = _InferOutputTypes(op_def, attr_protos);

                  // Add Op to graph
                  var op = g.create_op(op_type_name,
                      inputs.ToArray(),
                      output_types.ToArray(),
                      name: scope,
                      input_types: input_types.ToArray(),
                      attrs: attr_protos,
                      op_def: op_def);
                  return op;
              });
          }

          /// <summary>
          /// Maps the name of every <c>"type"</c> attr that declares a default value to that default dtype.
          /// </summary>
          private static Dictionary<string, DataType> _BuildDefaultTypeAttrMap(OpDef op_def)
          {
              var default_types = new Dictionary<string, DataType>();
              foreach (var attr_def in op_def.Attr)
              {
                  if (attr_def.Type == "type" && attr_def.DefaultValue != null)
                      default_types[attr_def.Name] = attr_def.DefaultValue.Type;
              }
              return default_types;
          }

          /// <summary>
          /// Builds one <see cref="AttrValue"/> proto per attr declared by <paramref name="op_def"/>.
          /// Attrs that were neither supplied nor inferred fall back to the OpDef's default value.
          /// </summary>
          /// <exception cref="TypeError">An attr has no value and no default.</exception>
          private Dictionary<string, AttrValue> _ConvertAttrsToProtos(OpDef op_def, Dictionary<string, object> attrs)
          {
              var attr_protos = new Dictionary<string, AttrValue>();
              foreach (var attr_def in op_def.Attr)
              {
                  var key = attr_def.Name;
                  if (attrs.TryGetValue(key, out var value))
                      attr_protos[key] = SetAttrValue(op_def, attr_def, value);
                  else if (attr_def.DefaultValue != null)
                      // Clone so the proto handed to the graph does not alias the OpDef's own default.
                      attr_protos[key] = attr_def.DefaultValue.Clone();
                  else
                      throw new TypeError($"No argument for attr '{key}' of '{op_def.Name}' Op.");
              }
              return attr_protos;
          }

          /// <summary>
          /// Determines the dtype of every output tensor from the op's output args and the resolved
          /// attr protos: <c>number_attr</c> outputs repeat their element type N times,
          /// <c>type_attr</c>/<c>type_list_attr</c> outputs read the attr, and others use the fixed arg type.
          /// Ref outputs are converted to their ref dtype.
          /// </summary>
          private List<TF_DataType> _InferOutputTypes(OpDef op_def, Dictionary<string, AttrValue> attr_protos)
          {
              var output_types = new List<TF_DataType>();
              foreach (var arg in op_def.OutputArg)
              {
                  IEnumerable<TF_DataType> types;
                  if (!string.IsNullOrEmpty(arg.NumberAttr))
                  {
                      var count = (int)attr_protos[arg.NumberAttr].I;
                      var element_type = !string.IsNullOrEmpty(arg.TypeAttr)
                          ? (TF_DataType)attr_protos[arg.TypeAttr].Type
                          : (TF_DataType)arg.Type;
                      types = Enumerable.Repeat(element_type, count);
                  }
                  else if (!string.IsNullOrEmpty(arg.TypeAttr))
                      types = new[] { (TF_DataType)attr_protos[arg.TypeAttr].Type };
                  else if (!string.IsNullOrEmpty(arg.TypeListAttr))
                      types = attr_protos[arg.TypeListAttr].List.Type.Select(x => (TF_DataType)x);
                  else
                      types = new[] { (TF_DataType)arg.Type };

                  if (arg.IsRef)
                      types = types.Select(x => x.as_ref());
                  output_types.AddRange(types);
              }
              return output_types;
          }

          /// <summary>
          /// Records the attrs implied by one converted input argument (list length, element type or
          /// type list) in <paramref name="attrs"/>, notes the originating input in
          /// <paramref name="inferred_from"/>, and appends the input's dtypes to <paramref name="input_types"/>.
          /// </summary>
          /// <param name="types">Dtypes of the converted input tensors.</param>
          /// <param name="base_types">Base dtypes of the converted input tensors.</param>
          /// <param name="input_types">Receives <paramref name="types"/> for ref inputs, otherwise <paramref name="base_types"/>.</param>
          /// <param name="values">The converted input, expected to be a <see cref="Tensor"/>[].</param>
          /// <exception cref="ValueError">A list input is shorter than its number attr's declared minimum.</exception>
          private void SetAttrs(string op_type_name,
              ArgDef input_arg,
              OpDef op_def,
              Dictionary<string, object> attrs,
              Dictionary<string, object> inferred_from,
              List<TF_DataType> types,
              List<TF_DataType> base_types,
              List<TF_DataType> input_types,
              dynamic values)
          {
              var input_name = input_arg.Name;
              if (!string.IsNullOrEmpty(input_arg.NumberAttr))
              {
                  // NOTE(review): when the length attr is already set, the list length is not cross-checked.
                  if (!attrs.ContainsKey(input_arg.NumberAttr))
                  {
                      Tensor[] tensors = values as Tensor[];
                      attrs[input_arg.NumberAttr] = tensors.Length;
                      inferred_from[input_arg.NumberAttr] = input_name;
                      var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr);
                      if (num_attr.HasMinimum && tensors.Length < num_attr.Minimum)
                          throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {tensors.Length} shorter " +
                              $"than minimum length {num_attr.Minimum}");
                  }

                  // The first tensor's base type is recorded (the other elements are not checked against it).
                  // An empty list carries no type, so the attr is left to an explicit value or the OpDef default.
                  if (input_arg.Type == DataType.DtInvalid && base_types.Count > 0)
                  {
                      attrs[input_arg.TypeAttr] = base_types[0];
                      inferred_from[input_arg.TypeAttr] = input_name;
                  }
              }
              else if (!string.IsNullOrEmpty(input_arg.TypeAttr))
              {
                  // A value already present in attrs (from an earlier input) is kept.
                  if (!attrs.ContainsKey(input_arg.TypeAttr))
                  {
                      attrs[input_arg.TypeAttr] = base_types[0];
                      inferred_from[input_arg.TypeAttr] = input_name;
                  }
              }
              else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
              {
                  if (!attrs.ContainsKey(input_arg.TypeListAttr))
                  {
                      attrs[input_arg.TypeListAttr] = base_types;
                      inferred_from[input_arg.TypeListAttr] = input_name;
                  }
              }

              input_types.AddRange(input_arg.IsRef ? types : base_types);
          }

          /// <summary>
          /// Converts <paramref name="v"/> to the <see cref="DataType"/> enum of its base dtype.
          /// <paramref name="attr_def"/> is currently unused (no allowed-values check is performed).
          /// </summary>
          public DataType _MakeType(TF_DataType v, AttrDef attr_def)
          {
              return v.as_base_dtype().as_datatype_enum();
          }

          /// <summary>
          /// Converts a single attr value to its <see cref="AttrValue"/> proto according to <c>attr_def.Type</c>.
          /// </summary>
          /// <exception cref="ValueError">An int attr is below its declared minimum.</exception>
          /// <exception cref="TypeError">The attr type is not supported.</exception>
          private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
          {
              var attr_value = new AttrValue();
              // List-valued attrs need their ListValue container before elements are added below.
              if (attr_def.Type.StartsWith("list(", StringComparison.Ordinal))
                  attr_value.List = new AttrValue.Types.ListValue();

              switch (attr_def.Type)
              {
                  case "string":
                      attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
                      break;
                  case "type":
                      attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                      break;
                  case "list(type)":
                      attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
                      break;
                  case "list(int)":
                      // 64-bit lists are stored as-is; int lists are widened element by element.
                      if (value is IEnumerable<long> long_list)
                          attr_value.List.I.AddRange(long_list);
                      else
                          attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
                      break;
                  case "bool":
                      attr_value.B = (bool)value;
                      break;
                  case "float":
                      // A boxed double cannot be unboxed as float directly, so narrow it explicitly.
                      attr_value.F = value is double double_value ? (float)double_value : (float)value;
                      break;
                  case "int":
                      // A boxed long cannot be unboxed as int, so accept it explicitly.
                      attr_value.I = value is long long_value ? long_value : (int)value;
                      if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum)
                          throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
                      break;
                  case "shape":
                      if (value == null && attr_def.DefaultValue != null)
                          attr_value.Shape = attr_def.DefaultValue.Shape;
                      else if (value is TensorShape tensor_shape)
                          attr_value.Shape = tensor_shape.as_proto();
                      else if (value is long[] long_dims)
                          attr_value.Shape = tensor_util.as_shape(long_dims);
                      else if (value is int[] int_dims)
                          attr_value.Shape = tensor_util.as_shape(int_dims);
                      break;
                  default:
                      throw new TypeError($"SetAttrValue: can't convert attr_def.Type '{attr_def.Type}' to protos.");
              }
              return attr_value;
          }

          /// <summary>
          /// True when the argument takes a list: it has a length attr (<c>number_attr</c>) or a type-list attr.
          /// </summary>
          private bool _IsListParameter(ArgDef arg)
              => !String.IsNullOrEmpty(arg.NumberAttr) || !String.IsNullOrEmpty(arg.TypeListAttr);

          /// <summary>
          /// True when <paramref name="v"/> is accepted as a list input; only <see cref="Tensor"/>[] qualifies.
          /// </summary>
          private bool _IsListValue(object v)
              => v is Tensor[];
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。